401 lines
15 KiB
Python
401 lines
15 KiB
Python
"""
|
|
ID3 Classifier - Production Grade
|
|
|
|
Improvements over v1:
|
|
- Chi-squared pruning to prevent overfitting on sparse branches
|
|
- Confidence scores on every prediction (Laplace smoothed)
|
|
- Gain-ratio variant for high-cardinality features
|
|
- Serialization (to_dict / from_dict / to_json / from_json)
|
|
- Per-feature importance scores
|
|
- Full prediction audit trail via explain()
|
|
- min_samples_split and min_info_gain stopping criteria
|
|
"""
|
|
|
|
import math
|
|
import json
|
|
import logging
|
|
from collections import Counter
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ID3Classifier:
|
|
"""
|
|
ID3 decision tree (entropy / information-gain splitting).
|
|
All predict* methods work even if the model has never been trained -
|
|
they return safe defaults rather than raising.
|
|
"""
|
|
|
|
    def __init__(
        self,
        max_depth: int = 6,
        min_samples_split: int = 5,
        min_info_gain: float = 0.001,
        use_gain_ratio: bool = False,
        chi2_pruning: bool = True,
    ):
        """Configure the classifier; no training happens here.

        Args:
            max_depth: hard cap on tree depth during building.
            min_samples_split: nodes with fewer rows become leaves.
            min_info_gain: splits gaining less than this are rejected.
            use_gain_ratio: divide gain by split info (C4.5 style) to
                penalise high-cardinality features.
            chi2_pruning: post-prune splits that fail a chi-squared
                significance test at p=0.05.
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_info_gain = min_info_gain
        self.use_gain_ratio = use_gain_ratio
        self.chi2_pruning = chi2_pruning

        # Fitted state -- all populated by train(); safe defaults until then.
        self.tree: Any = None
        self.features: List[str] = []
        self.target: str = ""
        self.classes_: List[str] = []
        self.feature_importances_: Dict[str, float] = {}
        self.feature_values: Dict[str, List[str]] = {}  # unique values seen per feature
        self._n_samples: int = 0
        # Running per-feature info-gain totals, accumulated in _build_tree.
        self._total_gain: Dict[str, float] = {}
|
|
|
|
# ------------------------------------------------------------------ train
|
|
|
|
def train(self, data: List[Dict[str, Any]], target: str, features: List[str]) -> None:
|
|
if not data:
|
|
logger.warning("ID3: train() called with empty data.")
|
|
return
|
|
|
|
self.target = target
|
|
self.features = list(features)
|
|
self.classes_ = sorted({str(row.get(target)) for row in data})
|
|
self._total_gain = {f: 0.0 for f in features}
|
|
self._n_samples = len(data)
|
|
|
|
# Collect unique values per feature for dashboard display
|
|
self.feature_values = {
|
|
f: sorted({str(row.get(f)) for row in data if row.get(f) is not None})
|
|
for f in features
|
|
}
|
|
|
|
self.tree = self._build_tree(data, list(features), target, depth=0)
|
|
|
|
if self.chi2_pruning:
|
|
self.tree = self._prune(self.tree, data, target)
|
|
|
|
total_gain = sum(self._total_gain.values()) or 1.0
|
|
self.feature_importances_ = {
|
|
f: round(v / total_gain, 4) for f, v in self._total_gain.items()
|
|
}
|
|
logger.info(
|
|
f"ID3: trained on {len(data)} samples | "
|
|
f"classes={self.classes_} | importances={self.feature_importances_}"
|
|
)
|
|
|
|
# ----------------------------------------------------------- predict API
|
|
|
|
def predict(self, sample: Dict[str, Any]) -> Tuple[str, float]:
|
|
"""Return (label, confidence 0-1). Safe to call before training."""
|
|
if self.tree is None:
|
|
return "Unknown", 0.0
|
|
label, proba = self._classify(self.tree, sample, [])
|
|
confidence = proba.get(str(label), 0.0) if isinstance(proba, dict) else 1.0
|
|
return str(label), round(confidence, 4)
|
|
|
|
def predict_proba(self, sample: Dict[str, Any]) -> Dict[str, float]:
|
|
"""Full class probability distribution."""
|
|
if self.tree is None:
|
|
return {}
|
|
_, proba = self._classify(self.tree, sample, [])
|
|
return proba if isinstance(proba, dict) else {str(proba): 1.0}
|
|
|
|
def explain(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Human-readable decision path for audit / dashboard display."""
|
|
if self.tree is None:
|
|
return {"prediction": "Unknown", "confidence": 0.0, "decision_path": []}
|
|
path: List[str] = []
|
|
label, proba = self._classify(self.tree, sample, path)
|
|
return {
|
|
"prediction": str(label),
|
|
"confidence": round(proba.get(str(label), 1.0), 4),
|
|
"probabilities": proba,
|
|
"decision_path": path,
|
|
}
|
|
|
|
# ---------------------------------------------------------- serialisation
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
return {
|
|
"tree": self.tree,
|
|
"features": self.features,
|
|
"target": self.target,
|
|
"classes": self.classes_,
|
|
"feature_importances": self.feature_importances_,
|
|
"feature_values": self.feature_values,
|
|
"n_samples": self._n_samples,
|
|
"params": {
|
|
"max_depth": self.max_depth,
|
|
"min_samples_split": self.min_samples_split,
|
|
"min_info_gain": self.min_info_gain,
|
|
"use_gain_ratio": self.use_gain_ratio,
|
|
"chi2_pruning": self.chi2_pruning,
|
|
},
|
|
}
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: Dict[str, Any]) -> "ID3Classifier":
|
|
p = d.get("params", {})
|
|
obj = cls(
|
|
max_depth=p.get("max_depth", 6),
|
|
min_samples_split=p.get("min_samples_split", 5),
|
|
min_info_gain=p.get("min_info_gain", 0.001),
|
|
use_gain_ratio=p.get("use_gain_ratio", False),
|
|
chi2_pruning=p.get("chi2_pruning", True),
|
|
)
|
|
obj.tree = d["tree"]
|
|
obj.features = d["features"]
|
|
obj.target = d["target"]
|
|
obj.classes_ = d["classes"]
|
|
obj.feature_importances_ = d.get("feature_importances", {})
|
|
obj.feature_values = d.get("feature_values", {})
|
|
obj._n_samples = d.get("n_samples", 0)
|
|
return obj
|
|
|
|
def to_json(self) -> str:
|
|
return json.dumps(self.to_dict(), indent=2)
|
|
|
|
@classmethod
|
|
def from_json(cls, s: str) -> "ID3Classifier":
|
|
return cls.from_dict(json.loads(s))
|
|
|
|
def summary(self) -> Dict[str, Any]:
|
|
return {
|
|
"n_samples": self._n_samples,
|
|
"n_classes": len(self.classes_),
|
|
"classes": self.classes_,
|
|
"n_features": len(self.features),
|
|
"feature_importances": self.feature_importances_,
|
|
"feature_values": self.feature_values,
|
|
"trained": self.tree is not None,
|
|
}
|
|
|
|
    @property
    def classes(self) -> List[str]:
        """Alias for classes_ for compatibility.

        Returns the same list object as ``classes_`` (not a copy).
        """
        return self.classes_
|
|
|
|
def get_tree_rules(self) -> List[str]:
|
|
"""Extract human-readable if/then rules from the trained tree."""
|
|
rules: List[str] = []
|
|
if self.tree is None:
|
|
return rules
|
|
self._extract_rules(self.tree, [], rules)
|
|
return rules
|
|
|
|
def _extract_rules(self, node: Any, conditions: List[str], rules: List[str]) -> None:
|
|
"""Recursively walk the tree and collect decision paths as strings."""
|
|
if not isinstance(node, dict):
|
|
return
|
|
if node.get("__leaf__"):
|
|
label = node.get("__label__", "?")
|
|
proba = node.get("__proba__", {})
|
|
conf = proba.get(str(label), 0.0)
|
|
prefix = " AND ".join(conditions) if conditions else "(root)"
|
|
rules.append(f"{prefix} => {label} ({conf:.0%})")
|
|
return
|
|
feature = node.get("__feature__", "?")
|
|
for val, child in node.get("__branches__", {}).items():
|
|
self._extract_rules(child, conditions + [f"{feature}={val}"], rules)
|
|
|
|
# --------------------------------------------------------- tree building
|
|
|
|
def _build_tree(
|
|
self,
|
|
data: List[Dict[str, Any]],
|
|
features: List[str],
|
|
target: str,
|
|
depth: int,
|
|
) -> Any:
|
|
counts = Counter(str(row.get(target)) for row in data)
|
|
|
|
# Pure node
|
|
if len(counts) == 1:
|
|
return self._make_leaf(data, target)
|
|
|
|
# Stopping criteria
|
|
if not features or depth >= self.max_depth or len(data) < self.min_samples_split:
|
|
return self._make_leaf(data, target)
|
|
|
|
best_f, best_gain = self._best_split(data, features, target)
|
|
if best_f is None or best_gain < self.min_info_gain:
|
|
return self._make_leaf(data, target)
|
|
|
|
self._total_gain[best_f] = self._total_gain.get(best_f, 0.0) + best_gain
|
|
|
|
remaining = [f for f in features if f != best_f]
|
|
node = {
|
|
"__feature__": best_f,
|
|
"__gain__": round(best_gain, 6),
|
|
"__n__": len(data),
|
|
"__branches__": {},
|
|
}
|
|
for val in {row.get(best_f) for row in data}:
|
|
subset = [r for r in data if r.get(best_f) == val]
|
|
node["__branches__"][str(val)] = self._build_tree(
|
|
subset, remaining, target, depth + 1
|
|
)
|
|
return node
|
|
|
|
def _make_leaf(self, data: List[Dict[str, Any]], target: str) -> Dict[str, Any]:
|
|
counts = Counter(str(row.get(target)) for row in data)
|
|
total = len(data)
|
|
k = len(self.classes_) or 1
|
|
# Laplace smoothing
|
|
proba = {
|
|
cls: round((counts.get(cls, 0) + 1) / (total + k), 4)
|
|
for cls in self.classes_
|
|
}
|
|
label = max(proba, key=proba.get)
|
|
return {"__leaf__": True, "__label__": label, "__proba__": proba, "__n__": total}
|
|
|
|
# ---------------------------------------------------------- splitting
|
|
|
|
def _best_split(
|
|
self, data: List[Dict[str, Any]], features: List[str], target: str
|
|
) -> Tuple[Optional[str], float]:
|
|
base_e = self._entropy(data, target)
|
|
best_f, best_gain = None, -1.0
|
|
for f in features:
|
|
gain = self._info_gain(data, f, target, base_e)
|
|
if self.use_gain_ratio:
|
|
si = self._split_info(data, f)
|
|
gain = gain / si if si > 0 else 0.0
|
|
if gain > best_gain:
|
|
best_gain = gain
|
|
best_f = f
|
|
return best_f, best_gain
|
|
|
|
# ----------------------------------------------------------- pruning
|
|
|
|
def _prune(self, node: Any, data: List[Dict[str, Any]], target: str) -> Any:
|
|
if not isinstance(node, dict) or node.get("__leaf__"):
|
|
return node
|
|
feature = node["__feature__"]
|
|
# Recurse children first
|
|
for val in list(node["__branches__"].keys()):
|
|
subset = [r for r in data if str(r.get(feature)) == str(val)]
|
|
node["__branches__"][val] = self._prune(node["__branches__"][val], subset, target)
|
|
# Chi-squared test: if split is not significant, collapse to leaf
|
|
if not self._chi2_significant(data, feature, target):
|
|
return self._make_leaf(data, target)
|
|
return node
|
|
|
|
def _chi2_significant(
|
|
self, data: List[Dict[str, Any]], feature: str, target: str
|
|
) -> bool:
|
|
classes = self.classes_
|
|
feature_vals = list({str(r.get(feature)) for r in data})
|
|
if not classes or len(feature_vals) < 2:
|
|
return False
|
|
total = len(data)
|
|
class_totals = Counter(str(r.get(target)) for r in data)
|
|
chi2 = 0.0
|
|
for val in feature_vals:
|
|
subset = [r for r in data if str(r.get(feature)) == val]
|
|
n_val = len(subset)
|
|
val_counts = Counter(str(r.get(target)) for r in subset)
|
|
for cls in classes:
|
|
observed = val_counts.get(cls, 0)
|
|
expected = (n_val * class_totals.get(cls, 0)) / total
|
|
if expected > 0:
|
|
chi2 += (observed - expected) ** 2 / expected
|
|
df = (len(feature_vals) - 1) * (len(classes) - 1)
|
|
if df <= 0:
|
|
return False
|
|
# Critical values at p=0.05
|
|
crit_table = {1: 3.841, 2: 5.991, 3: 7.815, 4: 9.488, 5: 11.070, 6: 12.592}
|
|
crit = crit_table.get(df, 3.841 * df)
|
|
return chi2 > crit
|
|
|
|
# ---------------------------------------------------------- classify
|
|
|
|
    def _classify(
        self, node: Any, row: Dict[str, Any], path: List[str]
    ) -> Tuple[Any, Any]:
        """Walk the tree for ``row``, appending each decision step to ``path``.

        Returns (label, proba); ``proba`` is a class->probability dict
        (a bare non-dict node degenerates to a certainty distribution).
        ``path`` is mutated in place to build the audit trail for explain().
        """
        # Degenerate node (raw value stored instead of a node dict).
        if not isinstance(node, dict):
            return node, {str(node): 1.0}
        if node.get("__leaf__"):
            label = node["__label__"]
            proba = node["__proba__"]
            path.append(f"predict={label} (p={proba.get(label, 0):.2f})")
            return label, proba

        feature = node["__feature__"]
        # Branch keys were stringified at build time, so stringify the lookup too.
        value = str(row.get(feature, ""))
        path.append(f"{feature}={value}")

        branches = node["__branches__"]
        if value in branches:
            return self._classify(branches[value], row, path)

        # Unseen value: weighted vote from all leaf children
        # (non-leaf children are skipped by this fallback).
        all_proba: Counter = Counter()
        total_n = 0
        for child in branches.values():
            if isinstance(child, dict) and child.get("__leaf__"):
                n = child.get("__n__", 1)
                total_n += n
                for cls, p in child.get("__proba__", {}).items():
                    # Weight each leaf's distribution by its training sample count.
                    all_proba[cls] += p * n

        if not total_n:
            # No leaf children available to vote: fall back to the first known class.
            fallback = self.classes_[0] if self.classes_ else "Unknown"
            path.append(f"unseen fallback: {fallback}")
            return fallback, {fallback: 1.0}

        proba = {cls: round(v / total_n, 4) for cls, v in all_proba.items()}
        label = max(proba, key=proba.get)
        path.append(f"weighted vote: {label}")
        return label, proba
|
|
|
|
# ---------------------------------------------------------- entropy math
|
|
|
|
def _entropy(self, data: List[Dict[str, Any]], target: str) -> float:
|
|
if not data:
|
|
return 0.0
|
|
counts = Counter(str(row.get(target)) for row in data)
|
|
total = len(data)
|
|
return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)
|
|
|
|
def _info_gain(
|
|
self,
|
|
data: List[Dict[str, Any]],
|
|
feature: str,
|
|
target: str,
|
|
base_entropy: Optional[float] = None,
|
|
) -> float:
|
|
if base_entropy is None:
|
|
base_entropy = self._entropy(data, target)
|
|
total = len(data)
|
|
buckets: Dict[Any, list] = {}
|
|
for row in data:
|
|
buckets.setdefault(row.get(feature), []).append(row)
|
|
weighted = sum(
|
|
(len(sub) / total) * self._entropy(sub, target) for sub in buckets.values()
|
|
)
|
|
return base_entropy - weighted
|
|
|
|
def _split_info(self, data: List[Dict[str, Any]], feature: str) -> float:
|
|
total = len(data)
|
|
counts = Counter(row.get(feature) for row in data)
|
|
return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)
|
|
|
|
|
|
# ------------------------------------------------------------------ factory
|
|
|
|
def get_behavior_model(
    max_depth: int = 5,
    min_samples_split: int = 8,
    min_info_gain: float = 0.005,
    use_gain_ratio: bool = True,
    chi2_pruning: bool = True,
) -> ID3Classifier:
    """Factory returning an :class:`ID3Classifier` with tuned defaults."""
    params = dict(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_info_gain=min_info_gain,
        use_gain_ratio=use_gain_ratio,
        chi2_pruning=chi2_pruning,
    )
    return ID3Classifier(**params)
|