# Source file: routesapi/app/services/ml/id3_classifier.py
# (401 lines, 15 KiB, Python)
"""
ID3 Classifier - Production Grade
Improvements over v1:
- Chi-squared pruning to prevent overfitting on sparse branches
- Confidence scores on every prediction (Laplace smoothed)
- Gain-ratio variant for high-cardinality features
- Serialization (to_dict / from_dict / to_json / from_json)
- Per-feature importance scores
- Full prediction audit trail via explain()
- min_samples_split and min_info_gain stopping criteria
"""
import math
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)


class ID3Classifier:
    """
    ID3 decision tree (entropy / information-gain splitting).

    Features beyond a textbook ID3:
      * optional gain-ratio splitting for high-cardinality features
      * chi-squared post-pruning of statistically insignificant splits
      * Laplace-smoothed class probabilities on every leaf
      * JSON-safe serialization (to_dict / from_dict / to_json / from_json)
      * per-feature importance scores and a full explain() audit trail

    All predict* methods work even if the model has never been trained -
    they return safe defaults rather than raising.
    """

    def __init__(
        self,
        max_depth: int = 6,
        min_samples_split: int = 5,
        min_info_gain: float = 0.001,
        use_gain_ratio: bool = False,
        chi2_pruning: bool = True,
    ):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_info_gain = min_info_gain
        self.use_gain_ratio = use_gain_ratio
        self.chi2_pruning = chi2_pruning
        self.tree: Any = None  # nested dict once trained, None otherwise
        self.features: List[str] = []
        self.target: str = ""
        self.classes_: List[str] = []  # sorted class labels, stringified
        self.feature_importances_: Dict[str, float] = {}
        # Unique values seen per feature (kept for dashboard display).
        self.feature_values: Dict[str, List[str]] = {}
        self._n_samples: int = 0
        # Accumulated information gain per feature; normalised into
        # feature_importances_ at the end of train().
        self._total_gain: Dict[str, float] = {}

    # ------------------------------------------------------------------ train
    def train(self, data: List[Dict[str, Any]], target: str, features: List[str]) -> None:
        """Fit the tree on `data` (list of row dicts).

        `target` is the key holding the class label; `features` are the keys
        to split on. Labels and feature values are compared as strings so
        heterogeneous row values behave consistently. A no-op on empty data.
        """
        if not data:
            logger.warning("ID3: train() called with empty data.")
            return
        self.target = target
        self.features = list(features)
        self.classes_ = sorted({str(row.get(target)) for row in data})
        self._total_gain = {f: 0.0 for f in features}
        self._n_samples = len(data)
        # Collect unique values per feature for dashboard display
        self.feature_values = {
            f: sorted({str(row.get(f)) for row in data if row.get(f) is not None})
            for f in features
        }
        self.tree = self._build_tree(data, list(features), target, depth=0)
        if self.chi2_pruning:
            self.tree = self._prune(self.tree, data, target)
        # Avoid division by zero when no split produced any gain.
        total_gain = sum(self._total_gain.values()) or 1.0
        self.feature_importances_ = {
            f: round(v / total_gain, 4) for f, v in self._total_gain.items()
        }
        logger.info(
            f"ID3: trained on {len(data)} samples | "
            f"classes={self.classes_} | importances={self.feature_importances_}"
        )

    # ----------------------------------------------------------- predict API
    def predict(self, sample: Dict[str, Any]) -> Tuple[str, float]:
        """Return (label, confidence 0-1). Safe to call before training."""
        if self.tree is None:
            return "Unknown", 0.0
        label, proba = self._classify(self.tree, sample, [])
        confidence = proba.get(str(label), 0.0) if isinstance(proba, dict) else 1.0
        return str(label), round(confidence, 4)

    def predict_proba(self, sample: Dict[str, Any]) -> Dict[str, float]:
        """Full class probability distribution (empty dict if untrained)."""
        if self.tree is None:
            return {}
        _, proba = self._classify(self.tree, sample, [])
        return proba if isinstance(proba, dict) else {str(proba): 1.0}

    def explain(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Human-readable decision path for audit / dashboard display."""
        if self.tree is None:
            return {"prediction": "Unknown", "confidence": 0.0, "decision_path": []}
        path: List[str] = []
        label, proba = self._classify(self.tree, sample, path)
        return {
            "prediction": str(label),
            # Default 0.0 to stay consistent with predict(); the old 1.0
            # fallback over-reported confidence for labels missing from proba.
            "confidence": round(proba.get(str(label), 0.0), 4),
            "probabilities": proba,
            "decision_path": path,
        }

    # ---------------------------------------------------------- serialisation
    def to_dict(self) -> Dict[str, Any]:
        """JSON-safe snapshot of the model (tree, metadata, hyper-params)."""
        return {
            "tree": self.tree,
            "features": self.features,
            "target": self.target,
            "classes": self.classes_,
            "feature_importances": self.feature_importances_,
            "feature_values": self.feature_values,
            "n_samples": self._n_samples,
            "params": {
                "max_depth": self.max_depth,
                "min_samples_split": self.min_samples_split,
                "min_info_gain": self.min_info_gain,
                "use_gain_ratio": self.use_gain_ratio,
                "chi2_pruning": self.chi2_pruning,
            },
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ID3Classifier":
        """Rebuild a model from a to_dict() snapshot."""
        p = d.get("params", {})
        obj = cls(
            max_depth=p.get("max_depth", 6),
            min_samples_split=p.get("min_samples_split", 5),
            min_info_gain=p.get("min_info_gain", 0.001),
            use_gain_ratio=p.get("use_gain_ratio", False),
            chi2_pruning=p.get("chi2_pruning", True),
        )
        obj.tree = d["tree"]
        obj.features = d["features"]
        obj.target = d["target"]
        obj.classes_ = d["classes"]
        obj.feature_importances_ = d.get("feature_importances", {})
        obj.feature_values = d.get("feature_values", {})
        obj._n_samples = d.get("n_samples", 0)
        return obj

    def to_json(self) -> str:
        """Serialize to a JSON string (see to_dict)."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, s: str) -> "ID3Classifier":
        """Deserialize a model produced by to_json()."""
        return cls.from_dict(json.loads(s))

    def summary(self) -> Dict[str, Any]:
        """Compact training summary for dashboards / health endpoints."""
        return {
            "n_samples": self._n_samples,
            "n_classes": len(self.classes_),
            "classes": self.classes_,
            "n_features": len(self.features),
            "feature_importances": self.feature_importances_,
            "feature_values": self.feature_values,
            "trained": self.tree is not None,
        }

    @property
    def classes(self) -> List[str]:
        """Alias for classes_ for compatibility."""
        return self.classes_

    def get_tree_rules(self) -> List[str]:
        """Extract human-readable if/then rules from the trained tree."""
        rules: List[str] = []
        if self.tree is None:
            return rules
        self._extract_rules(self.tree, [], rules)
        return rules

    def _extract_rules(self, node: Any, conditions: List[str], rules: List[str]) -> None:
        """Recursively walk the tree and collect decision paths as strings."""
        if not isinstance(node, dict):
            return
        if node.get("__leaf__"):
            label = node.get("__label__", "?")
            proba = node.get("__proba__", {})
            conf = proba.get(str(label), 0.0)
            prefix = " AND ".join(conditions) if conditions else "(root)"
            rules.append(f"{prefix} => {label} ({conf:.0%})")
            return
        feature = node.get("__feature__", "?")
        for val, child in node.get("__branches__", {}).items():
            self._extract_rules(child, conditions + [f"{feature}={val}"], rules)

    # --------------------------------------------------------- tree building
    def _build_tree(
        self,
        data: List[Dict[str, Any]],
        features: List[str],
        target: str,
        depth: int,
    ) -> Any:
        """Recursively grow the tree; returns a leaf dict or an internal node."""
        counts = Counter(str(row.get(target)) for row in data)
        # Pure node
        if len(counts) == 1:
            return self._make_leaf(data, target)
        # Stopping criteria
        if not features or depth >= self.max_depth or len(data) < self.min_samples_split:
            return self._make_leaf(data, target)
        best_f, best_gain = self._best_split(data, features, target)
        if best_f is None or best_gain < self.min_info_gain:
            return self._make_leaf(data, target)
        self._total_gain[best_f] = self._total_gain.get(best_f, 0.0) + best_gain
        remaining = [f for f in features if f != best_f]
        node = {
            "__feature__": best_f,
            "__gain__": round(best_gain, 6),
            "__n__": len(data),
            "__branches__": {},
        }
        # Partition rows by the *stringified* feature value. Branch keys were
        # already strings, so partitioning by raw value let two raw values
        # with identical str() forms overwrite each other's branch (losing
        # rows); string partitioning also matches _prune()/_classify() lookups.
        for val in {str(row.get(best_f)) for row in data}:
            subset = [r for r in data if str(r.get(best_f)) == val]
            node["__branches__"][val] = self._build_tree(
                subset, remaining, target, depth + 1
            )
        return node

    def _make_leaf(self, data: List[Dict[str, Any]], target: str) -> Dict[str, Any]:
        """Build a leaf with Laplace-smoothed class probabilities."""
        counts = Counter(str(row.get(target)) for row in data)
        total = len(data)
        k = len(self.classes_) or 1
        # Laplace smoothing: +1 per class so no class ever has probability 0.
        proba = {
            cls: round((counts.get(cls, 0) + 1) / (total + k), 4)
            for cls in self.classes_
        }
        label = max(proba, key=proba.get)
        return {"__leaf__": True, "__label__": label, "__proba__": proba, "__n__": total}

    # ---------------------------------------------------------- splitting
    def _best_split(
        self, data: List[Dict[str, Any]], features: List[str], target: str
    ) -> Tuple[Optional[str], float]:
        """Return (best feature, its gain); gain-ratio variant if configured."""
        base_e = self._entropy(data, target)
        best_f, best_gain = None, -1.0
        for f in features:
            gain = self._info_gain(data, f, target, base_e)
            if self.use_gain_ratio:
                # Gain ratio penalises many-valued features (C4.5 style).
                si = self._split_info(data, f)
                gain = gain / si if si > 0 else 0.0
            if gain > best_gain:
                best_gain = gain
                best_f = f
        return best_f, best_gain

    # ----------------------------------------------------------- pruning
    def _prune(self, node: Any, data: List[Dict[str, Any]], target: str) -> Any:
        """Bottom-up chi-squared pruning: collapse insignificant splits."""
        if not isinstance(node, dict) or node.get("__leaf__"):
            return node
        feature = node["__feature__"]
        # Recurse children first
        for val in list(node["__branches__"].keys()):
            subset = [r for r in data if str(r.get(feature)) == str(val)]
            node["__branches__"][val] = self._prune(node["__branches__"][val], subset, target)
        # Chi-squared test: if split is not significant, collapse to leaf
        if not self._chi2_significant(data, feature, target):
            return self._make_leaf(data, target)
        return node

    def _chi2_significant(
        self, data: List[Dict[str, Any]], feature: str, target: str
    ) -> bool:
        """Pearson chi-squared independence test of feature vs target (p=0.05)."""
        classes = self.classes_
        feature_vals = list({str(r.get(feature)) for r in data})
        if not classes or len(feature_vals) < 2:
            return False
        total = len(data)
        class_totals = Counter(str(r.get(target)) for r in data)
        chi2 = 0.0
        for val in feature_vals:
            subset = [r for r in data if str(r.get(feature)) == val]
            n_val = len(subset)
            val_counts = Counter(str(r.get(target)) for r in subset)
            for cls in classes:
                observed = val_counts.get(cls, 0)
                expected = (n_val * class_totals.get(cls, 0)) / total
                if expected > 0:
                    chi2 += (observed - expected) ** 2 / expected
        df = (len(feature_vals) - 1) * (len(classes) - 1)
        if df <= 0:
            return False
        # Critical values at p=0.05; beyond df=6 use a crude linear
        # extrapolation (conservative enough for pruning purposes).
        crit_table = {1: 3.841, 2: 5.991, 3: 7.815, 4: 9.488, 5: 11.070, 6: 12.592}
        crit = crit_table.get(df, 3.841 * df)
        return chi2 > crit

    # ---------------------------------------------------------- classify
    def _classify(
        self, node: Any, row: Dict[str, Any], path: List[str]
    ) -> Tuple[Any, Any]:
        """Walk the tree for `row`, appending each decision to `path`."""
        if not isinstance(node, dict):
            return node, {str(node): 1.0}
        if node.get("__leaf__"):
            label = node["__label__"]
            proba = node["__proba__"]
            path.append(f"predict={label} (p={proba.get(label, 0):.2f})")
            return label, proba
        feature = node["__feature__"]
        value = str(row.get(feature, ""))
        path.append(f"{feature}={value}")
        branches = node["__branches__"]
        if value in branches:
            return self._classify(branches[value], row, path)
        # Unseen value: weighted vote from all leaf children
        all_proba: Counter = Counter()
        total_n = 0
        for child in branches.values():
            if isinstance(child, dict) and child.get("__leaf__"):
                n = child.get("__n__", 1)
                total_n += n
                for cls, p in child.get("__proba__", {}).items():
                    all_proba[cls] += p * n
        if not total_n:
            # No leaf children to vote with: fall back to the first class.
            fallback = self.classes_[0] if self.classes_ else "Unknown"
            path.append(f"unseen fallback: {fallback}")
            return fallback, {fallback: 1.0}
        proba = {cls: round(v / total_n, 4) for cls, v in all_proba.items()}
        label = max(proba, key=proba.get)
        path.append(f"weighted vote: {label}")
        return label, proba

    # ---------------------------------------------------------- entropy math
    def _entropy(self, data: List[Dict[str, Any]], target: str) -> float:
        """Shannon entropy (bits) of the target label distribution."""
        if not data:
            return 0.0
        counts = Counter(str(row.get(target)) for row in data)
        total = len(data)
        return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)

    def _info_gain(
        self,
        data: List[Dict[str, Any]],
        feature: str,
        target: str,
        base_entropy: Optional[float] = None,
    ) -> float:
        """Information gain of splitting `data` on `feature`."""
        if base_entropy is None:
            base_entropy = self._entropy(data, target)
        total = len(data)
        buckets: Dict[Any, list] = {}
        for row in data:
            buckets.setdefault(row.get(feature), []).append(row)
        weighted = sum(
            (len(sub) / total) * self._entropy(sub, target) for sub in buckets.values()
        )
        return base_entropy - weighted

    def _split_info(self, data: List[Dict[str, Any]], feature: str) -> float:
        """Split information (intrinsic value) of `feature`, for gain ratio."""
        total = len(data)
        # Count stringified values for consistency with branch construction
        # and _entropy (raw values with equal str() forms count as one).
        counts = Counter(str(row.get(feature)) for row in data)
        return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)
# ------------------------------------------------------------------ factory
def get_behavior_model(
    max_depth: int = 5,
    min_samples_split: int = 8,
    min_info_gain: float = 0.005,
    use_gain_ratio: bool = True,
    chi2_pruning: bool = True,
) -> ID3Classifier:
    """Construct an ID3Classifier with behavior-model defaults.

    Compared with the class defaults, this preset uses a shallower tree,
    a larger split threshold, a higher gain floor, and gain-ratio splitting.
    """
    config = {
        "max_depth": max_depth,
        "min_samples_split": min_samples_split,
        "min_info_gain": min_info_gain,
        "use_gain_ratio": use_gain_ratio,
        "chi2_pruning": chi2_pruning,
    }
    return ID3Classifier(**config)