"""
ID3 Classifier - Production Grade

Improvements over v1:
- Chi-squared pruning to prevent overfitting on sparse branches
- Confidence scores on every prediction (Laplace smoothed)
- Gain-ratio variant for high-cardinality features
- Serialization (to_dict / from_dict / to_json / from_json)
- Per-feature importance scores
- Full prediction audit trail via explain()
- min_samples_split and min_info_gain stopping criteria
"""

import math
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class ID3Classifier:
    """
    ID3 decision tree (entropy / information-gain splitting).

    All predict* methods work even if the model has never been trained -
    they return safe defaults rather than raising.
    """

    def __init__(
        self,
        max_depth: int = 6,
        min_samples_split: int = 5,
        min_info_gain: float = 0.001,
        use_gain_ratio: bool = False,
        chi2_pruning: bool = True,
    ):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_info_gain = min_info_gain
        self.use_gain_ratio = use_gain_ratio
        self.chi2_pruning = chi2_pruning

        self.tree: Any = None
        self.features: List[str] = []
        self.target: str = ""
        self.classes_: List[str] = []
        self.feature_importances_: Dict[str, float] = {}
        # Unique values seen per feature (for dashboard display).
        self.feature_values: Dict[str, List[str]] = {}
        self._n_samples: int = 0
        # Accumulated information gain per feature; normalized into
        # feature_importances_ at the end of train().
        self._total_gain: Dict[str, float] = {}

    # ------------------------------------------------------------------ train
    def train(self, data: List[Dict[str, Any]], target: str, features: List[str]) -> None:
        """Fit the tree on a list of row-dicts.

        Args:
            data: training rows; each row maps feature/target names to values.
            target: key of the class label in each row.
            features: feature keys to consider for splitting.

        Empty ``data`` logs a warning and leaves the model untrained.
        """
        if not data:
            logger.warning("ID3: train() called with empty data.")
            return

        self.target = target
        self.features = list(features)
        # Labels are compared as strings throughout, so normalize here.
        self.classes_ = sorted({str(row.get(target)) for row in data})
        self._total_gain = {f: 0.0 for f in features}
        self._n_samples = len(data)

        # Collect unique values per feature for dashboard display.
        self.feature_values = {
            f: sorted({str(row.get(f)) for row in data if row.get(f) is not None})
            for f in features
        }

        self.tree = self._build_tree(data, list(features), target, depth=0)
        if self.chi2_pruning:
            self.tree = self._prune(self.tree, data, target)

        # Normalize accumulated gains into importances summing to ~1.
        total_gain = sum(self._total_gain.values()) or 1.0
        self.feature_importances_ = {
            f: round(v / total_gain, 4) for f, v in self._total_gain.items()
        }
        logger.info(
            "ID3: trained on %d samples | classes=%s | importances=%s",
            len(data),
            self.classes_,
            self.feature_importances_,
        )

    # ----------------------------------------------------------- predict API
    def predict(self, sample: Dict[str, Any]) -> Tuple[str, float]:
        """Return (label, confidence 0-1). Safe to call before training."""
        if self.tree is None:
            return "Unknown", 0.0
        label, proba = self._classify(self.tree, sample, [])
        confidence = proba.get(str(label), 0.0) if isinstance(proba, dict) else 1.0
        return str(label), round(confidence, 4)

    def predict_proba(self, sample: Dict[str, Any]) -> Dict[str, float]:
        """Full class probability distribution (empty dict if untrained)."""
        if self.tree is None:
            return {}
        _, proba = self._classify(self.tree, sample, [])
        return proba if isinstance(proba, dict) else {str(proba): 1.0}

    def explain(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Human-readable decision path for audit / dashboard display."""
        if self.tree is None:
            return {"prediction": "Unknown", "confidence": 0.0, "decision_path": []}
        path: List[str] = []
        label, proba = self._classify(self.tree, sample, path)
        return {
            "prediction": str(label),
            # Default 0.0 keeps the failure value consistent with predict().
            "confidence": round(proba.get(str(label), 0.0), 4),
            "probabilities": proba,
            "decision_path": path,
        }

    # ---------------------------------------------------------- serialisation
    def to_dict(self) -> Dict[str, Any]:
        """Serialize model state (tree + metadata + hyper-params) to a dict."""
        return {
            "tree": self.tree,
            "features": self.features,
            "target": self.target,
            "classes": self.classes_,
            "feature_importances": self.feature_importances_,
            "feature_values": self.feature_values,
            "n_samples": self._n_samples,
            "params": {
                "max_depth": self.max_depth,
                "min_samples_split": self.min_samples_split,
                "min_info_gain": self.min_info_gain,
                "use_gain_ratio": self.use_gain_ratio,
                "chi2_pruning": self.chi2_pruning,
            },
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ID3Classifier":
        """Rebuild a classifier from to_dict() output; missing keys get defaults."""
        p = d.get("params", {})
        obj = cls(
            max_depth=p.get("max_depth", 6),
            min_samples_split=p.get("min_samples_split", 5),
            min_info_gain=p.get("min_info_gain", 0.001),
            use_gain_ratio=p.get("use_gain_ratio", False),
            chi2_pruning=p.get("chi2_pruning", True),
        )
        # .get keeps a partial dict from raising and yields a safe untrained model.
        obj.tree = d.get("tree")
        obj.features = d.get("features", [])
        obj.target = d.get("target", "")
        obj.classes_ = d.get("classes", [])
        obj.feature_importances_ = d.get("feature_importances", {})
        obj.feature_values = d.get("feature_values", {})
        obj._n_samples = d.get("n_samples", 0)
        return obj

    def to_json(self) -> str:
        """JSON string form of to_dict()."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, s: str) -> "ID3Classifier":
        """Inverse of to_json()."""
        return cls.from_dict(json.loads(s))

    def summary(self) -> Dict[str, Any]:
        """Lightweight model overview for dashboards."""
        return {
            "n_samples": self._n_samples,
            "n_classes": len(self.classes_),
            "classes": self.classes_,
            "n_features": len(self.features),
            "feature_importances": self.feature_importances_,
            "feature_values": self.feature_values,
            "trained": self.tree is not None,
        }

    @property
    def classes(self) -> List[str]:
        """Alias for classes_ for compatibility."""
        return self.classes_

    def get_tree_rules(self) -> List[str]:
        """Extract human-readable if/then rules from the trained tree."""
        rules: List[str] = []
        if self.tree is None:
            return rules
        self._extract_rules(self.tree, [], rules)
        return rules

    def _extract_rules(self, node: Any, conditions: List[str], rules: List[str]) -> None:
        """Recursively walk the tree and collect decision paths as strings."""
        if not isinstance(node, dict):
            return
        if node.get("__leaf__"):
            label = node.get("__label__", "?")
            proba = node.get("__proba__", {})
            conf = proba.get(str(label), 0.0)
            prefix = " AND ".join(conditions) if conditions else "(root)"
            rules.append(f"{prefix} => {label} ({conf:.0%})")
            return
        feature = node.get("__feature__", "?")
        for val, child in node.get("__branches__", {}).items():
            self._extract_rules(child, conditions + [f"{feature}={val}"], rules)

    # --------------------------------------------------------- tree building
    def _build_tree(
        self,
        data: List[Dict[str, Any]],
        features: List[str],
        target: str,
        depth: int,
    ) -> Any:
        """Recursively grow the tree; returns either an internal node or a leaf dict."""
        counts = Counter(str(row.get(target)) for row in data)

        # Pure node
        if len(counts) == 1:
            return self._make_leaf(data, target)

        # Stopping criteria
        if not features or depth >= self.max_depth or len(data) < self.min_samples_split:
            return self._make_leaf(data, target)

        best_f, best_gain = self._best_split(data, features, target)
        if best_f is None or best_gain < self.min_info_gain:
            return self._make_leaf(data, target)

        self._total_gain[best_f] = self._total_gain.get(best_f, 0.0) + best_gain
        remaining = [f for f in features if f != best_f]
        node = {
            "__feature__": best_f,
            "__gain__": round(best_gain, 6),
            "__n__": len(data),
            "__branches__": {},
        }
        # Sort values so branch order (and thus serialization, rule extraction
        # and unseen-value tie-breaking) is deterministic across runs.
        for val in sorted({row.get(best_f) for row in data}, key=str):
            subset = [r for r in data if r.get(best_f) == val]
            node["__branches__"][str(val)] = self._build_tree(
                subset, remaining, target, depth + 1
            )
        return node

    def _make_leaf(self, data: List[Dict[str, Any]], target: str) -> Dict[str, Any]:
        """Build a leaf node with Laplace-smoothed class probabilities."""
        counts = Counter(str(row.get(target)) for row in data)
        total = len(data)
        k = len(self.classes_) or 1
        # Laplace smoothing: no class ever gets probability 0.
        proba = {
            cls: round((counts.get(cls, 0) + 1) / (total + k), 4)
            for cls in self.classes_
        }
        label = max(proba, key=proba.get)
        return {"__leaf__": True, "__label__": label, "__proba__": proba, "__n__": total}

    # ---------------------------------------------------------- splitting
    def _best_split(
        self, data: List[Dict[str, Any]], features: List[str], target: str
    ) -> Tuple[Optional[str], float]:
        """Return (best_feature, best_gain); gain-ratio normalized if configured."""
        base_e = self._entropy(data, target)
        best_f, best_gain = None, -1.0
        for f in features:
            gain = self._info_gain(data, f, target, base_e)
            if self.use_gain_ratio:
                # Gain ratio penalizes high-cardinality features (C4.5 style).
                si = self._split_info(data, f)
                gain = gain / si if si > 0 else 0.0
            if gain > best_gain:
                best_gain = gain
                best_f = f
        return best_f, best_gain

    # ----------------------------------------------------------- pruning
    def _prune(self, node: Any, data: List[Dict[str, Any]], target: str) -> Any:
        """Bottom-up chi-squared pruning: collapse insignificant splits to leaves."""
        if not isinstance(node, dict) or node.get("__leaf__"):
            return node
        feature = node["__feature__"]

        # Recurse children first
        for val in list(node["__branches__"].keys()):
            subset = [r for r in data if str(r.get(feature)) == str(val)]
            node["__branches__"][val] = self._prune(node["__branches__"][val], subset, target)

        # Chi-squared test: if split is not significant, collapse to leaf
        if not self._chi2_significant(data, feature, target):
            return self._make_leaf(data, target)
        return node

    def _chi2_significant(
        self, data: List[Dict[str, Any]], feature: str, target: str
    ) -> bool:
        """Pearson chi-squared independence test of feature vs target at p=0.05."""
        classes = self.classes_
        feature_vals = list({str(r.get(feature)) for r in data})
        if not classes or len(feature_vals) < 2:
            return False

        total = len(data)
        class_totals = Counter(str(r.get(target)) for r in data)
        chi2 = 0.0
        for val in feature_vals:
            subset = [r for r in data if str(r.get(feature)) == val]
            n_val = len(subset)
            val_counts = Counter(str(r.get(target)) for r in subset)
            for cls in classes:
                observed = val_counts.get(cls, 0)
                expected = (n_val * class_totals.get(cls, 0)) / total
                if expected > 0:
                    chi2 += (observed - expected) ** 2 / expected

        df = (len(feature_vals) - 1) * (len(classes) - 1)
        if df <= 0:
            return False
        # Critical values at p=0.05; linear extrapolation beyond df=6 is a
        # rough approximation of the true chi-squared quantile.
        crit_table = {1: 3.841, 2: 5.991, 3: 7.815, 4: 9.488, 5: 11.070, 6: 12.592}
        crit = crit_table.get(df, 3.841 * df)
        return chi2 > crit

    # ---------------------------------------------------------- classify
    def _classify(
        self, node: Any, row: Dict[str, Any], path: List[str]
    ) -> Tuple[Any, Any]:
        """Descend the tree for one row, appending each step to ``path``."""
        if not isinstance(node, dict):
            return node, {str(node): 1.0}
        if node.get("__leaf__"):
            label = node["__label__"]
            proba = node["__proba__"]
            path.append(f"predict={label} (p={proba.get(label, 0):.2f})")
            return label, proba

        feature = node["__feature__"]
        value = str(row.get(feature, ""))
        path.append(f"{feature}={value}")
        branches = node["__branches__"]
        if value in branches:
            return self._classify(branches[value], row, path)

        # Unseen value: weighted vote from all leaf children
        all_proba: Counter = Counter()
        total_n = 0
        for child in branches.values():
            if isinstance(child, dict) and child.get("__leaf__"):
                n = child.get("__n__", 1)
                total_n += n
                for cls, p in child.get("__proba__", {}).items():
                    all_proba[cls] += p * n
        if not total_n:
            # No leaf children to vote with: fall back to the first known class.
            fallback = self.classes_[0] if self.classes_ else "Unknown"
            path.append(f"unseen fallback: {fallback}")
            return fallback, {fallback: 1.0}
        proba = {cls: round(v / total_n, 4) for cls, v in all_proba.items()}
        label = max(proba, key=proba.get)
        path.append(f"weighted vote: {label}")
        return label, proba

    # ---------------------------------------------------------- entropy math
    def _entropy(self, data: List[Dict[str, Any]], target: str) -> float:
        """Shannon entropy (bits) of the target distribution; 0.0 for empty data."""
        if not data:
            return 0.0
        counts = Counter(str(row.get(target)) for row in data)
        total = len(data)
        return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)

    def _info_gain(
        self,
        data: List[Dict[str, Any]],
        feature: str,
        target: str,
        base_entropy: Optional[float] = None,
    ) -> float:
        """Information gain of splitting ``data`` on ``feature``."""
        if base_entropy is None:
            base_entropy = self._entropy(data, target)
        total = len(data)
        buckets: Dict[Any, list] = {}
        for row in data:
            buckets.setdefault(row.get(feature), []).append(row)
        weighted = sum(
            (len(sub) / total) * self._entropy(sub, target) for sub in buckets.values()
        )
        return base_entropy - weighted

    def _split_info(self, data: List[Dict[str, Any]], feature: str) -> float:
        """Intrinsic information of a split (denominator of gain ratio)."""
        total = len(data)
        counts = Counter(row.get(feature) for row in data)
        return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)


# ------------------------------------------------------------------ factory
def get_behavior_model(
    max_depth: int = 5,
    min_samples_split: int = 8,
    min_info_gain: float = 0.005,
    use_gain_ratio: bool = True,
    chi2_pruning: bool = True,
) -> ID3Classifier:
    """Build an ID3Classifier preconfigured for behavior-model training."""
    return ID3Classifier(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_info_gain=min_info_gain,
        use_gain_ratio=use_gain_ratio,
        chi2_pruning=chi2_pruning,
    )