312 lines
12 KiB
Python
312 lines
12 KiB
Python
"""
|
|
Behavior Analyzer - Production Grade
|
|
======================================
|
|
Analyzes historical assignment data using the ID3 decision tree to classify
|
|
assignment outcomes as 'SUCCESS' or 'RISK'.
|
|
|
|
Key fixes and upgrades over the original
|
|
------------------------------------------
|
|
1. BUG FIX: distance_band now uses `total_distance_km` (not `num_orders`).
|
|
2. BUG FIX: time_band input is always normalized to uppercase before predict.
|
|
3. Rich feature set: distance_band, time_band, load_band, order_density_band.
|
|
4. Returns (label, confidence) from the classifier - exposes uncertainty.
|
|
5. Trend analysis: tracks rolling success rate over recent N windows.
|
|
6. Tree persistence: saves/loads trained tree as JSON to survive restarts.
|
|
7. Feature importance proxy: logs which features drove the split.
|
|
8. Thread-safe lazy training via a simple lock.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import threading
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from app.services.ml.id3_classifier import ID3Classifier, get_behavior_model
|
|
|
|
# Module-level logger named after this module; used for training, persistence
# and failure messages below.
logger = logging.getLogger(__name__)
|
|
|
|
# Path of the SQLite database holding the `assignment_ml_log` table; can be
# overridden with the ML_DB_PATH environment variable.
_DB_PATH = os.getenv("ML_DB_PATH", "ml_data/ml_store.db")
|
|
# Path of the JSON file the trained tree is written to and restored from; can
# be overridden with the ML_TREE_PATH environment variable.
_TREE_PATH = os.getenv("ML_TREE_PATH", "ml_data/behavior_tree.json")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Band encoders (discrete labels for ID3)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def distance_band(km: float) -> str:
    """Bucket a total route distance (km) into a discrete band for ID3.

    Upper bounds are inclusive: <= 5 is "SHORT", <= 15 is "MID", <= 30 is
    "LONG"; anything that matches none of those is "VERY_LONG".
    """
    thresholds = ((5.0, "SHORT"), (15.0, "MID"), (30.0, "LONG"))
    for upper_km, label in thresholds:
        if km <= upper_km:
            return label
    return "VERY_LONG"
|
|
|
|
|
|
def time_band(ts_str: str) -> str:
    """Map an ISO-format timestamp string to a time-of-day band.

    Returns "UNKNOWN" when the value cannot be parsed by
    `datetime.fromisoformat`. Hours outside every listed range
    (23:00 through 05:59) fall into "LATE_NIGHT".
    """
    try:
        hour_of_day = datetime.fromisoformat(ts_str).hour
    except Exception:
        # Best-effort encoder: any parsing problem yields a dedicated band.
        return "UNKNOWN"

    # (start hour inclusive, end hour exclusive, label)
    windows = (
        (6, 10, "MORNING_RUSH"),
        (10, 12, "LATE_MORNING"),
        (12, 14, "LUNCH_RUSH"),
        (14, 17, "AFTERNOON"),
        (17, 20, "EVENING_RUSH"),
        (20, 23, "NIGHT"),
    )
    for start, stop, label in windows:
        if start <= hour_of_day < stop:
            return label
    return "LATE_NIGHT"
|
|
|
|
|
|
def load_band(avg_load: float) -> str:
    """Bucket the average orders-per-rider figure into a load band.

    Upper bounds are inclusive: <= 2 is "LIGHT", <= 5 is "MODERATE", <= 8 is
    "HEAVY"; anything that matches none of those is "OVERLOADED".
    """
    cutoffs = ((2.0, "LIGHT"), (5.0, "MODERATE"), (8.0, "HEAVY"))
    for ceiling, label in cutoffs:
        if avg_load <= ceiling:
            return label
    return "OVERLOADED"
|
|
|
|
|
|
def order_density_band(num_orders: int, num_riders: int) -> str:
    """Bucket the orders-per-available-rider ratio into a density band.

    A rider count of zero short-circuits to "NO_RIDERS" so no division is
    attempted. Otherwise the ratio's upper bounds are inclusive: <= 2 is
    "SPARSE", <= 5 is "NORMAL", <= 9 is "DENSE", else "OVERLOADED".
    """
    if num_riders == 0:
        return "NO_RIDERS"

    orders_per_rider = num_orders / num_riders
    cutoffs = ((2.0, "SPARSE"), (5.0, "NORMAL"), (9.0, "DENSE"))
    for ceiling, label in cutoffs:
        if orders_per_rider <= ceiling:
            return label
    return "OVERLOADED"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Behavior Analyzer
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class BehaviorAnalyzer:
    """
    Trains an ID3 tree on historical assignment logs and predicts whether
    a new assignment context is likely to SUCCEED or be at RISK.

    Features used
    -------------
    - distance_band       : total route distance bucket
    - time_band           : time-of-day bucket
    - load_band           : average load per rider bucket
    - order_density_band  : orders-per-rider ratio bucket

    Target
    ------
    - is_success: "SUCCESS" if unassigned_count == 0, else "RISK"
    """

    TARGET = "is_success"
    FEATURES = ["distance_band", "time_band", "load_band", "order_density_band"]

    # Time-of-day labels that predict() accepts verbatim (case-insensitively)
    # instead of an ISO timestamp.
    _KNOWN_TIME_BANDS = frozenset({
        "MORNING_RUSH", "LATE_MORNING", "LUNCH_RUSH",
        "AFTERNOON", "EVENING_RUSH", "NIGHT", "LATE_NIGHT", "UNKNOWN",
    })

    def __init__(self):
        """Set up an untrained analyzer, then try to restore a persisted tree."""
        self._db_path = _DB_PATH
        self._tree_path = _TREE_PATH
        self.model: ID3Classifier = get_behavior_model(max_depth=5)
        self.is_trained: bool = False
        self._lock = threading.Lock()  # serializes train_on_history()
        self._training_size: int = 0
        self._success_rate: float = 0.0
        self._rules: List[str] = []
        self._recent_trend: List[float] = []

        # If a tree file exists this flips is_trained to True without touching
        # the database; _training_size / _success_rate stay at their defaults.
        self._load_tree()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train_on_history(self, limit: int = 2000) -> Dict[str, Any]:
        """Fetch the most recent rows from SQLite and rebuild the tree.

        Args:
            limit: maximum number of log rows to read (newest first).

        Returns:
            A dict whose "status" is one of:
            - "ok"                : tree rebuilt; also carries training_rows,
                                    success_rate, n_rules, classes and
                                    feature_values.
            - "insufficient_data" : fewer than 10 rows were available.
            - "preprocess_failed" : no row could be converted to features.
            - "error"             : an exception occurred ("message" holds it).
            Exceptions are logged and reported in the dict, never raised.
        """
        with self._lock:
            try:
                rows = self._fetch_rows(limit)
                if len(rows) < 10:
                    logger.warning(f"ID3 BehaviorAnalyzer: only {len(rows)} rows - need >=10.")
                    return {"status": "insufficient_data", "rows": len(rows)}

                training_data, successes = self._preprocess(rows)

                if not training_data:
                    return {"status": "preprocess_failed", "rows": len(rows)}

                self.model.train(
                    data=training_data,
                    target=self.TARGET,
                    features=self.FEATURES,
                )
                self.is_trained = True
                self._training_size = len(training_data)
                self._success_rate = successes / len(training_data)
                self._rules = self.model.get_tree_rules()
                self._compute_trend(rows)
                self._save_tree()

                summary = {
                    "status": "ok",
                    "training_rows": self._training_size,
                    "success_rate": round(self._success_rate, 4),
                    "n_rules": len(self._rules),
                    "classes": self.model.classes,
                    "feature_values": self.model.feature_values,
                }
                logger.info(
                    f"ID3 BehaviorAnalyzer trained - rows={self._training_size}, "
                    f"success_rate={self._success_rate:.1%}, rules={len(self._rules)}"
                )
                return summary

            except Exception as e:
                logger.error(f"ID3 BehaviorAnalyzer training failed: {e}", exc_info=True)
                return {"status": "error", "message": str(e)}

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(self, distance_km: float, timestamp_or_band: str,
                avg_load: float = 4.0, num_orders: int = 5,
                num_riders: int = 2) -> Dict[str, Any]:
        """Predict whether an assignment context will SUCCEED or be at RISK.

        Args:
            distance_km: total route distance, bucketed by distance_band().
            timestamp_or_band: either one of the known time-band labels
                (any letter case) or an ISO timestamp to be bucketed by
                time_band(). Non-string values are converted with str().
            avg_load: average orders per rider, bucketed by load_band().
            num_orders: order count for order_density_band().
            num_riders: rider count for order_density_band().

        Returns:
            Dict with "label", "confidence" (rounded to 4 places),
            "features_used" and "model_trained". When no tree is available the
            neutral default ("SUCCESS", 0.5, {}, False) is returned.

        Note:
            This method does not take the training lock.
        """
        if not self.is_trained:
            return {
                "label": "SUCCESS",
                "confidence": 0.5,
                "features_used": {},
                "model_trained": False,
            }

        # Coerce and trim so a datetime object or a padded label does not
        # crash on .upper() or silently degrade to "UNKNOWN".
        raw_time = str(timestamp_or_band).strip()
        as_label = raw_time.upper()
        t_band = as_label if as_label in self._KNOWN_TIME_BANDS else time_band(raw_time)

        features_used = {
            "distance_band": distance_band(distance_km),
            "time_band": t_band,
            "load_band": load_band(avg_load),
            "order_density_band": order_density_band(num_orders, num_riders),
        }

        label, confidence = self.model.predict(features_used)
        return {
            "label": label,
            "confidence": round(confidence, 4),
            "features_used": features_used,
            "model_trained": True,
        }

    # ------------------------------------------------------------------
    # Info / Diagnostics
    # ------------------------------------------------------------------

    def get_info(self) -> Dict[str, Any]:
        """Return a diagnostic snapshot of the analyzer state.

        At most the first 20 rules are included. Model-derived fields
        ("feature_values", "classes") are empty while no tree is available.
        """
        return {
            "is_trained": self.is_trained,
            "training_rows": self._training_size,
            "success_rate": round(self._success_rate, 4),
            "n_rules": len(self._rules),
            "rules": self._rules[:20],
            "recent_trend": self._recent_trend,
            "feature_names": self.FEATURES,
            "feature_values": self.model.feature_values if self.is_trained else {},
            "classes": self.model.classes if self.is_trained else [],
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _fetch_rows(self, limit: int) -> List[Dict]:
        """Read up to `limit` rows of assignment_ml_log, newest id first.

        Raises:
            sqlite3.Error: propagated to the caller; the connection is closed
            either way.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            conn.row_factory = sqlite3.Row
            fetched = conn.execute(
                "SELECT * FROM assignment_ml_log ORDER BY id DESC LIMIT ?", (limit,)
            ).fetchall()
            return [dict(r) for r in fetched]
        finally:
            # Close even when the query fails (e.g. missing table) so the
            # connection / file handle is not leaked.
            conn.close()

    @staticmethod
    def _row_succeeded(row: Dict) -> Optional[bool]:
        """Return True/False for a row's outcome, or None if it is unparsable.

        A missing or NULL unassigned_count counts as 0, the same rule
        _preprocess() applies when labelling training rows.
        """
        try:
            return int(row.get("unassigned_count") or 0) == 0
        except (TypeError, ValueError):
            return None

    def _preprocess(self, rows: List[Dict]) -> Tuple[List[Dict], int]:
        """Convert raw log rows into banded training examples.

        Returns:
            (training_data, successes) where training_data holds one dict per
            usable row (the four feature bands plus the TARGET label) and
            successes counts rows labelled "SUCCESS". Rows whose values cannot
            be converted are skipped.
        """
        training_data: List[Dict] = []
        successes = 0
        for r in rows:
            try:
                dist_km = float(r.get("total_distance_km") or 0.0)
                ts = str(r.get("timestamp") or "")
                avg_ld = float(r.get("avg_load") or 0.0)
                n_orders = int(r.get("num_orders") or 0)

                # Default to one rider only when the value is absent. A
                # recorded 0 must stay 0 so the "NO_RIDERS" band that
                # predict() can emit is also present in the training data.
                riders_raw = r.get("num_riders")
                n_riders = 1 if riders_raw in (None, "") else int(riders_raw)

                unassigned = int(r.get("unassigned_count") or 0)

                label = "SUCCESS" if unassigned == 0 else "RISK"
                if label == "SUCCESS":
                    successes += 1

                training_data.append({
                    "distance_band": distance_band(dist_km),
                    "time_band": time_band(ts),
                    "load_band": load_band(avg_ld),
                    "order_density_band": order_density_band(n_orders, n_riders),
                    self.TARGET: label,
                })
            except Exception:
                # Best effort: one malformed row must not abort training.
                continue
        return training_data, successes

    def _compute_trend(self, rows: List[Dict], window: int = 50) -> None:
        """Store per-window success rates in self._recent_trend.

        `rows` is expected newest-first, as returned by _fetch_rows(). The
        result holds at most 20 rates, ordered newest window first, each
        rounded to 4 places. Rows whose unassigned_count cannot be parsed are
        ignored; a window with no parsable row contributes no entry.
        """
        trend: List[float] = []
        for start in range(0, len(rows), window):
            outcomes = [
                ok
                for ok in map(self._row_succeeded, rows[start:start + window])
                if ok is not None
            ]
            if outcomes:
                trend.append(round(sum(outcomes) / len(outcomes), 4))
        # Rows are newest-first, so the most recent windows are at the front.
        self._recent_trend = trend[:20]

    def _save_tree(self) -> None:
        """Write the model's JSON form to the tree path; failures only warn."""
        try:
            os.makedirs(os.path.dirname(self._tree_path) or ".", exist_ok=True)
            with open(self._tree_path, "w") as f:
                f.write(self.model.to_json())
            logger.info(f"ID3 tree persisted -> {self._tree_path}")
        except Exception as e:
            logger.warning(f"ID3 tree save failed: {e}")

    def _load_tree(self) -> None:
        """Restore a persisted tree if the file exists; failures only warn."""
        try:
            if not os.path.exists(self._tree_path):
                return
            with open(self._tree_path) as f:
                self.model = ID3Classifier.from_json(f.read())
            self.is_trained = True
            self._rules = self.model.get_tree_rules()
            logger.info(f"ID3 tree restored - rules={len(self._rules)}")
        except Exception as e:
            logger.warning(f"ID3 tree load failed (will retrain): {e}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level singleton
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Shared analyzer instance, created by the first get_analyzer() call.
_analyzer: Optional[BehaviorAnalyzer] = None
# Held by get_analyzer() while it creates and, if needed, trains the instance.
_analyzer_lock = threading.Lock()


def get_analyzer() -> BehaviorAnalyzer:
    """Return the shared BehaviorAnalyzer, creating it on first use.

    While the instance reports is_trained == False, every call attempts
    train_on_history() again before returning it.
    """
    global _analyzer
    with _analyzer_lock:
        instance = _analyzer
        if instance is None:
            instance = _analyzer = BehaviorAnalyzer()
        if not instance.is_trained:
            instance.train_on_history()
        return instance
|