initial project setup with README and ignore
This commit is contained in:
400
app/services/ml/id3_classifier.py
Normal file
400
app/services/ml/id3_classifier.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
ID3 Classifier - Production Grade
|
||||
|
||||
Improvements over v1:
|
||||
- Chi-squared pruning to prevent overfitting on sparse branches
|
||||
- Confidence scores on every prediction (Laplace smoothed)
|
||||
- Gain-ratio variant for high-cardinality features
|
||||
- Serialization (to_dict / from_dict / to_json / from_json)
|
||||
- Per-feature importance scores
|
||||
- Full prediction audit trail via explain()
|
||||
- min_samples_split and min_info_gain stopping criteria
|
||||
"""
|
||||
|
||||
import math
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ID3Classifier:
    """
    ID3 decision tree (entropy / information-gain splitting).
    All predict* methods work even if the model has never been trained -
    they return safe defaults rather than raising.
    """

    def __init__(
        self,
        max_depth: int = 6,
        min_samples_split: int = 5,
        min_info_gain: float = 0.001,
        use_gain_ratio: bool = False,
        chi2_pruning: bool = True,
    ):
        """Configure hyper-parameters; no training happens here.

        Args:
            max_depth: maximum tree depth before forcing a leaf.
            min_samples_split: nodes smaller than this become leaves.
            min_info_gain: minimum gain required to accept a split.
            use_gain_ratio: use C4.5-style gain ratio instead of raw gain.
            chi2_pruning: run chi-squared post-pruning after building.
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_info_gain = min_info_gain
        self.use_gain_ratio = use_gain_ratio
        self.chi2_pruning = chi2_pruning

        # Learned state — populated by train(); safe defaults until then.
        self.tree: Any = None                              # nested dict of nodes/leaves
        self.features: List[str] = []                      # feature names used at fit time
        self.target: str = ""                              # target column name
        self.classes_: List[str] = []                      # sorted class labels (stringified)
        self.feature_importances_: Dict[str, float] = {}   # normalised gain per feature
        self.feature_values: Dict[str, List[str]] = {}  # unique values seen per feature
        self._n_samples: int = 0                           # training-set size
        self._total_gain: Dict[str, float] = {}            # raw accumulated gain per feature
# ------------------------------------------------------------------ train
|
||||
|
||||
def train(self, data: List[Dict[str, Any]], target: str, features: List[str]) -> None:
|
||||
if not data:
|
||||
logger.warning("ID3: train() called with empty data.")
|
||||
return
|
||||
|
||||
self.target = target
|
||||
self.features = list(features)
|
||||
self.classes_ = sorted({str(row.get(target)) for row in data})
|
||||
self._total_gain = {f: 0.0 for f in features}
|
||||
self._n_samples = len(data)
|
||||
|
||||
# Collect unique values per feature for dashboard display
|
||||
self.feature_values = {
|
||||
f: sorted({str(row.get(f)) for row in data if row.get(f) is not None})
|
||||
for f in features
|
||||
}
|
||||
|
||||
self.tree = self._build_tree(data, list(features), target, depth=0)
|
||||
|
||||
if self.chi2_pruning:
|
||||
self.tree = self._prune(self.tree, data, target)
|
||||
|
||||
total_gain = sum(self._total_gain.values()) or 1.0
|
||||
self.feature_importances_ = {
|
||||
f: round(v / total_gain, 4) for f, v in self._total_gain.items()
|
||||
}
|
||||
logger.info(
|
||||
f"ID3: trained on {len(data)} samples | "
|
||||
f"classes={self.classes_} | importances={self.feature_importances_}"
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------- predict API
|
||||
|
||||
def predict(self, sample: Dict[str, Any]) -> Tuple[str, float]:
|
||||
"""Return (label, confidence 0-1). Safe to call before training."""
|
||||
if self.tree is None:
|
||||
return "Unknown", 0.0
|
||||
label, proba = self._classify(self.tree, sample, [])
|
||||
confidence = proba.get(str(label), 0.0) if isinstance(proba, dict) else 1.0
|
||||
return str(label), round(confidence, 4)
|
||||
|
||||
def predict_proba(self, sample: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Full class probability distribution."""
|
||||
if self.tree is None:
|
||||
return {}
|
||||
_, proba = self._classify(self.tree, sample, [])
|
||||
return proba if isinstance(proba, dict) else {str(proba): 1.0}
|
||||
|
||||
def explain(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Human-readable decision path for audit / dashboard display."""
|
||||
if self.tree is None:
|
||||
return {"prediction": "Unknown", "confidence": 0.0, "decision_path": []}
|
||||
path: List[str] = []
|
||||
label, proba = self._classify(self.tree, sample, path)
|
||||
return {
|
||||
"prediction": str(label),
|
||||
"confidence": round(proba.get(str(label), 1.0), 4),
|
||||
"probabilities": proba,
|
||||
"decision_path": path,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------- serialisation
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"tree": self.tree,
|
||||
"features": self.features,
|
||||
"target": self.target,
|
||||
"classes": self.classes_,
|
||||
"feature_importances": self.feature_importances_,
|
||||
"feature_values": self.feature_values,
|
||||
"n_samples": self._n_samples,
|
||||
"params": {
|
||||
"max_depth": self.max_depth,
|
||||
"min_samples_split": self.min_samples_split,
|
||||
"min_info_gain": self.min_info_gain,
|
||||
"use_gain_ratio": self.use_gain_ratio,
|
||||
"chi2_pruning": self.chi2_pruning,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "ID3Classifier":
|
||||
p = d.get("params", {})
|
||||
obj = cls(
|
||||
max_depth=p.get("max_depth", 6),
|
||||
min_samples_split=p.get("min_samples_split", 5),
|
||||
min_info_gain=p.get("min_info_gain", 0.001),
|
||||
use_gain_ratio=p.get("use_gain_ratio", False),
|
||||
chi2_pruning=p.get("chi2_pruning", True),
|
||||
)
|
||||
obj.tree = d["tree"]
|
||||
obj.features = d["features"]
|
||||
obj.target = d["target"]
|
||||
obj.classes_ = d["classes"]
|
||||
obj.feature_importances_ = d.get("feature_importances", {})
|
||||
obj.feature_values = d.get("feature_values", {})
|
||||
obj._n_samples = d.get("n_samples", 0)
|
||||
return obj
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s: str) -> "ID3Classifier":
|
||||
return cls.from_dict(json.loads(s))
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"n_samples": self._n_samples,
|
||||
"n_classes": len(self.classes_),
|
||||
"classes": self.classes_,
|
||||
"n_features": len(self.features),
|
||||
"feature_importances": self.feature_importances_,
|
||||
"feature_values": self.feature_values,
|
||||
"trained": self.tree is not None,
|
||||
}
|
||||
|
||||
    @property
    def classes(self) -> List[str]:
        """Alias for classes_ for compatibility."""
        # Read-only mirror of the sklearn-style trailing-underscore attribute.
        return self.classes_
def get_tree_rules(self) -> List[str]:
|
||||
"""Extract human-readable if/then rules from the trained tree."""
|
||||
rules: List[str] = []
|
||||
if self.tree is None:
|
||||
return rules
|
||||
self._extract_rules(self.tree, [], rules)
|
||||
return rules
|
||||
|
||||
def _extract_rules(self, node: Any, conditions: List[str], rules: List[str]) -> None:
|
||||
"""Recursively walk the tree and collect decision paths as strings."""
|
||||
if not isinstance(node, dict):
|
||||
return
|
||||
if node.get("__leaf__"):
|
||||
label = node.get("__label__", "?")
|
||||
proba = node.get("__proba__", {})
|
||||
conf = proba.get(str(label), 0.0)
|
||||
prefix = " AND ".join(conditions) if conditions else "(root)"
|
||||
rules.append(f"{prefix} => {label} ({conf:.0%})")
|
||||
return
|
||||
feature = node.get("__feature__", "?")
|
||||
for val, child in node.get("__branches__", {}).items():
|
||||
self._extract_rules(child, conditions + [f"{feature}={val}"], rules)
|
||||
|
||||
    # --------------------------------------------------------- tree building

    def _build_tree(
        self,
        data: List[Dict[str, Any]],
        features: List[str],
        target: str,
        depth: int,
    ) -> Any:
        """Recursively grow the ID3 tree; returns an internal node or a leaf dict.

        A leaf is produced when the node is pure, no candidate features
        remain, max_depth is reached, the node has fewer than
        min_samples_split rows, or the best split's gain is below
        min_info_gain. Each chosen split's gain is accumulated into
        self._total_gain for the importance report built by train().
        """
        counts = Counter(str(row.get(target)) for row in data)

        # Pure node
        if len(counts) == 1:
            return self._make_leaf(data, target)

        # Stopping criteria
        if not features or depth >= self.max_depth or len(data) < self.min_samples_split:
            return self._make_leaf(data, target)

        best_f, best_gain = self._best_split(data, features, target)
        if best_f is None or best_gain < self.min_info_gain:
            return self._make_leaf(data, target)

        # Accumulated before pruning — NOTE(review): gains recorded here are
        # kept even if _prune() later collapses this subtree.
        self._total_gain[best_f] = self._total_gain.get(best_f, 0.0) + best_gain

        # ID3 removes the chosen feature from deeper levels (categorical splits).
        remaining = [f for f in features if f != best_f]
        node = {
            "__feature__": best_f,
            "__gain__": round(best_gain, 6),
            "__n__": len(data),
            "__branches__": {},
        }
        # One branch per raw feature value seen at this node; branch keys are
        # stringified to match the lookups done in _classify()/_prune().
        for val in {row.get(best_f) for row in data}:
            subset = [r for r in data if r.get(best_f) == val]
            node["__branches__"][str(val)] = self._build_tree(
                subset, remaining, target, depth + 1
            )
        return node
def _make_leaf(self, data: List[Dict[str, Any]], target: str) -> Dict[str, Any]:
|
||||
counts = Counter(str(row.get(target)) for row in data)
|
||||
total = len(data)
|
||||
k = len(self.classes_) or 1
|
||||
# Laplace smoothing
|
||||
proba = {
|
||||
cls: round((counts.get(cls, 0) + 1) / (total + k), 4)
|
||||
for cls in self.classes_
|
||||
}
|
||||
label = max(proba, key=proba.get)
|
||||
return {"__leaf__": True, "__label__": label, "__proba__": proba, "__n__": total}
|
||||
|
||||
# ---------------------------------------------------------- splitting
|
||||
|
||||
def _best_split(
|
||||
self, data: List[Dict[str, Any]], features: List[str], target: str
|
||||
) -> Tuple[Optional[str], float]:
|
||||
base_e = self._entropy(data, target)
|
||||
best_f, best_gain = None, -1.0
|
||||
for f in features:
|
||||
gain = self._info_gain(data, f, target, base_e)
|
||||
if self.use_gain_ratio:
|
||||
si = self._split_info(data, f)
|
||||
gain = gain / si if si > 0 else 0.0
|
||||
if gain > best_gain:
|
||||
best_gain = gain
|
||||
best_f = f
|
||||
return best_f, best_gain
|
||||
|
||||
    # ----------------------------------------------------------- pruning

    def _prune(self, node: Any, data: List[Dict[str, Any]], target: str) -> Any:
        """Bottom-up chi-squared pruning.

        Children are pruned first, each against the data subset that reaches
        its branch; then, if this node's own split is not statistically
        significant on its data, the whole subtree collapses into one leaf.
        Leaves and non-dict nodes pass through unchanged.
        """
        if not isinstance(node, dict) or node.get("__leaf__"):
            return node
        feature = node["__feature__"]
        # Recurse children first
        for val in list(node["__branches__"].keys()):
            # Branch keys were stringified at build time, so compare as str.
            subset = [r for r in data if str(r.get(feature)) == str(val)]
            node["__branches__"][val] = self._prune(node["__branches__"][val], subset, target)
        # Chi-squared test: if split is not significant, collapse to leaf
        if not self._chi2_significant(data, feature, target):
            return self._make_leaf(data, target)
        return node
def _chi2_significant(
|
||||
self, data: List[Dict[str, Any]], feature: str, target: str
|
||||
) -> bool:
|
||||
classes = self.classes_
|
||||
feature_vals = list({str(r.get(feature)) for r in data})
|
||||
if not classes or len(feature_vals) < 2:
|
||||
return False
|
||||
total = len(data)
|
||||
class_totals = Counter(str(r.get(target)) for r in data)
|
||||
chi2 = 0.0
|
||||
for val in feature_vals:
|
||||
subset = [r for r in data if str(r.get(feature)) == val]
|
||||
n_val = len(subset)
|
||||
val_counts = Counter(str(r.get(target)) for r in subset)
|
||||
for cls in classes:
|
||||
observed = val_counts.get(cls, 0)
|
||||
expected = (n_val * class_totals.get(cls, 0)) / total
|
||||
if expected > 0:
|
||||
chi2 += (observed - expected) ** 2 / expected
|
||||
df = (len(feature_vals) - 1) * (len(classes) - 1)
|
||||
if df <= 0:
|
||||
return False
|
||||
# Critical values at p=0.05
|
||||
crit_table = {1: 3.841, 2: 5.991, 3: 7.815, 4: 9.488, 5: 11.070, 6: 12.592}
|
||||
crit = crit_table.get(df, 3.841 * df)
|
||||
return chi2 > crit
|
||||
|
||||
    # ---------------------------------------------------------- classify

    def _classify(
        self, node: Any, row: Dict[str, Any], path: List[str]
    ) -> Tuple[Any, Any]:
        """Route *row* down the tree; returns (label, probability dict).

        Appends a human-readable step to *path* at every hop (consumed by
        explain()). For a feature value never seen in training, falls back
        to a sample-count-weighted vote over this node's LEAF children;
        if no leaf children exist, falls back to the first known class.
        """
        # Degenerate non-dict node: treat its value as a certain label.
        if not isinstance(node, dict):
            return node, {str(node): 1.0}
        if node.get("__leaf__"):
            label = node["__label__"]
            proba = node["__proba__"]
            path.append(f"predict={label} (p={proba.get(label, 0):.2f})")
            return label, proba

        feature = node["__feature__"]
        # Branch keys are stringified at build time, so stringify the lookup too.
        value = str(row.get(feature, ""))
        path.append(f"{feature}={value}")

        branches = node["__branches__"]
        if value in branches:
            return self._classify(branches[value], row, path)

        # Unseen value: weighted vote from all leaf children
        all_proba: Counter = Counter()
        total_n = 0
        for child in branches.values():
            if isinstance(child, dict) and child.get("__leaf__"):
                n = child.get("__n__", 1)
                total_n += n
                # Weight each leaf's distribution by its training sample count.
                for cls, p in child.get("__proba__", {}).items():
                    all_proba[cls] += p * n

        # No leaf children to vote with — fall back to the first known class.
        if not total_n:
            fallback = self.classes_[0] if self.classes_ else "Unknown"
            path.append(f"unseen fallback: {fallback}")
            return fallback, {fallback: 1.0}

        proba = {cls: round(v / total_n, 4) for cls, v in all_proba.items()}
        label = max(proba, key=proba.get)
        path.append(f"weighted vote: {label}")
        return label, proba
# ---------------------------------------------------------- entropy math
|
||||
|
||||
def _entropy(self, data: List[Dict[str, Any]], target: str) -> float:
|
||||
if not data:
|
||||
return 0.0
|
||||
counts = Counter(str(row.get(target)) for row in data)
|
||||
total = len(data)
|
||||
return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)
|
||||
|
||||
def _info_gain(
|
||||
self,
|
||||
data: List[Dict[str, Any]],
|
||||
feature: str,
|
||||
target: str,
|
||||
base_entropy: Optional[float] = None,
|
||||
) -> float:
|
||||
if base_entropy is None:
|
||||
base_entropy = self._entropy(data, target)
|
||||
total = len(data)
|
||||
buckets: Dict[Any, list] = {}
|
||||
for row in data:
|
||||
buckets.setdefault(row.get(feature), []).append(row)
|
||||
weighted = sum(
|
||||
(len(sub) / total) * self._entropy(sub, target) for sub in buckets.values()
|
||||
)
|
||||
return base_entropy - weighted
|
||||
|
||||
def _split_info(self, data: List[Dict[str, Any]], feature: str) -> float:
|
||||
total = len(data)
|
||||
counts = Counter(row.get(feature) for row in data)
|
||||
return -sum((c / total) * math.log2(c / total) for c in counts.values() if c > 0)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ factory


def get_behavior_model(
    max_depth: int = 5,
    min_samples_split: int = 8,
    min_info_gain: float = 0.005,
    use_gain_ratio: bool = True,
    chi2_pruning: bool = True,
) -> ID3Classifier:
    """Factory for an ID3Classifier with behaviour-model defaults.

    Defaults differ from the class's own (shallower tree, larger minimum
    split, gain-ratio splitting) to resist overfitting on behaviour data.
    """
    config = {
        "max_depth": max_depth,
        "min_samples_split": min_samples_split,
        "min_info_gain": min_info_gain,
        "use_gain_ratio": use_gain_ratio,
        "chi2_pruning": chi2_pruning,
    }
    return ID3Classifier(**config)
Reference in New Issue
Block a user