initial project setup with README and ignore
This commit is contained in:
311
app/services/ml/behavior_analyzer.py
Normal file
311
app/services/ml/behavior_analyzer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Behavior Analyzer - Production Grade
|
||||
======================================
|
||||
Analyzes historical assignment data using the ID3 decision tree to classify
|
||||
assignment outcomes as 'SUCCESS' or 'RISK'.
|
||||
|
||||
Key fixes and upgrades over the original
|
||||
------------------------------------------
|
||||
1. BUG FIX: distance_band now uses `total_distance_km` (not `num_orders`).
|
||||
2. BUG FIX: time_band input is always normalized to uppercase before predict.
|
||||
3. Rich feature set: distance_band, time_band, load_band, order_density_band.
|
||||
4. Returns (label, confidence) from the classifier - exposes uncertainty.
|
||||
5. Trend analysis: tracks rolling success rate over recent N windows.
|
||||
6. Tree persistence: saves/loads trained tree as JSON to survive restarts.
|
||||
7. Feature importance proxy: logs which features drove the split.
|
||||
8. Thread-safe lazy training via a simple lock.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.services.ml.id3_classifier import ID3Classifier, get_behavior_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DB_PATH = os.getenv("ML_DB_PATH", "ml_data/ml_store.db")
|
||||
_TREE_PATH = os.getenv("ML_TREE_PATH", "ml_data/behavior_tree.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Band encoders (discrete labels for ID3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def distance_band(km: float) -> str:
    """Map a total route distance in kilometres onto a discrete band label."""
    # Ordered (upper-bound, label) pairs; first match wins.
    bands = (
        (5.0, "SHORT"),
        (15.0, "MID"),
        (30.0, "LONG"),
    )
    for upper, label in bands:
        if km <= upper:
            return label
    return "VERY_LONG"
|
||||
|
||||
|
||||
def time_band(ts_str: str) -> str:
    """Map an ISO-8601 timestamp string onto a time-of-day band.

    Returns "UNKNOWN" when the timestamp cannot be parsed, so callers can
    feed raw DB values without pre-validating them.
    """
    try:
        hour = datetime.fromisoformat(ts_str).hour
    except (ValueError, TypeError):
        # fromisoformat raises ValueError on malformed strings and TypeError
        # on non-string input. Catching only these (instead of a bare
        # Exception) lets genuine programming errors surface.
        return "UNKNOWN"
    if 6 <= hour < 10:
        return "MORNING_RUSH"
    if 10 <= hour < 12:
        return "LATE_MORNING"
    if 12 <= hour < 14:
        return "LUNCH_RUSH"
    if 14 <= hour < 17:
        return "AFTERNOON"
    if 17 <= hour < 20:
        return "EVENING_RUSH"
    if 20 <= hour < 23:
        return "NIGHT"
    return "LATE_NIGHT"
|
||||
|
||||
|
||||
def load_band(avg_load: float) -> str:
    """Map average orders-per-rider onto a discrete load band."""
    # Check from heaviest to lightest so each guard is a strict lower bound.
    if avg_load > 8.0:
        return "OVERLOADED"
    if avg_load > 5.0:
        return "HEAVY"
    if avg_load > 2.0:
        return "MODERATE"
    return "LIGHT"
|
||||
|
||||
|
||||
def order_density_band(num_orders: int, num_riders: int) -> str:
    """Map the orders-per-available-rider ratio onto a density band.

    Returns "NO_RIDERS" when there is no rider capacity at all. The guard
    is `<= 0` (not `== 0`): a negative rider count would otherwise produce
    a nonsensical negative ratio that falls into "SPARSE".
    """
    if num_riders <= 0:
        return "NO_RIDERS"
    ratio = num_orders / num_riders
    if ratio <= 2.0:
        return "SPARSE"
    if ratio <= 5.0:
        return "NORMAL"
    if ratio <= 9.0:
        return "DENSE"
    return "OVERLOADED"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Behavior Analyzer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BehaviorAnalyzer:
    """
    Trains an ID3 tree on historical assignment logs and predicts whether
    a new assignment context is likely to SUCCEED or be at RISK.

    Features used
    -------------
    - distance_band      : total route distance bucket
    - time_band          : time-of-day bucket
    - load_band          : average load per rider bucket
    - order_density_band : orders-per-rider ratio bucket

    Target
    ------
    - is_success: "SUCCESS" if unassigned_count == 0, else "RISK"
    """

    # Name of the label column in the training dicts.
    TARGET = "is_success"
    # Discrete feature columns fed to the ID3 classifier.
    FEATURES = ["distance_band", "time_band", "load_band", "order_density_band"]

    def __init__(self):
        self._db_path = _DB_PATH
        self._tree_path = _TREE_PATH
        # Fresh (untrained) model; may be replaced by _load_tree() below.
        self.model: ID3Classifier = get_behavior_model(max_depth=5)
        self.is_trained: bool = False
        # Serializes train_on_history() so concurrent callers cannot
        # rebuild the tree while another training run is in flight.
        self._lock = threading.Lock()
        self._training_size: int = 0
        self._success_rate: float = 0.0
        self._rules: List[str] = []
        self._recent_trend: List[float] = []

        # Restore a previously persisted tree, if one exists on disk.
        self._load_tree()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train_on_history(self, limit: int = 2000) -> Dict[str, Any]:
        """Fetch the most recent rows from SQLite and rebuild the tree.

        Parameters
        ----------
        limit : int
            Maximum number of recent log rows to train on.

        Returns
        -------
        dict
            Summary with ``status`` one of "ok", "insufficient_data",
            "preprocess_failed", or "error".
        """
        with self._lock:
            try:
                rows = self._fetch_rows(limit)
                # ID3 on fewer than 10 samples overfits to noise.
                if len(rows) < 10:
                    logger.warning(f"ID3 BehaviorAnalyzer: only {len(rows)} rows - need >=10.")
                    return {"status": "insufficient_data", "rows": len(rows)}

                training_data, successes = self._preprocess(rows)

                if not training_data:
                    return {"status": "preprocess_failed", "rows": len(rows)}

                self.model.train(
                    data=training_data,
                    target=self.TARGET,
                    features=self.FEATURES,
                )
                self.is_trained = True
                self._training_size = len(training_data)
                self._success_rate = successes / len(training_data)
                self._rules = self.model.get_tree_rules()
                self._compute_trend(rows)
                # Persist so a restart does not force an immediate retrain.
                self._save_tree()

                summary = {
                    "status": "ok",
                    "training_rows": self._training_size,
                    "success_rate": round(self._success_rate, 4),
                    "n_rules": len(self._rules),
                    "classes": self.model.classes,
                    "feature_values": self.model.feature_values,
                }
                logger.info(
                    f"ID3 BehaviorAnalyzer trained - rows={self._training_size}, "
                    f"success_rate={self._success_rate:.1%}, rules={len(self._rules)}"
                )
                return summary

            except Exception as e:
                # Training is best-effort: report the failure instead of
                # propagating it into the caller's request path.
                logger.error(f"ID3 BehaviorAnalyzer training failed: {e}", exc_info=True)
                return {"status": "error", "message": str(e)}

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(self, distance_km: float, timestamp_or_band: str,
                avg_load: float = 4.0, num_orders: int = 5,
                num_riders: int = 2) -> Dict[str, Any]:
        """Predict whether an assignment context will SUCCEED or be at RISK.

        Parameters
        ----------
        distance_km : float
            Total route distance in kilometres.
        timestamp_or_band : str
            Either an ISO timestamp or an already-computed time band
            (case-insensitive); normalized before hitting the model.
        avg_load : float
            Average orders per rider.
        num_orders, num_riders : int
            Counts used for the order-density feature.

        Returns
        -------
        dict
            label, confidence, features_used, model_trained. When the model
            is untrained, returns an optimistic 50/50 default rather than
            raising.
        """
        if not self.is_trained:
            return {
                "label": "SUCCESS",
                "confidence": 0.5,
                "features_used": {},
                "model_trained": False,
            }

        # Accept either a pre-banded value or a raw timestamp; uppercase
        # before the membership test so "morning_rush" also matches.
        KNOWN_BANDS = {
            "MORNING_RUSH", "LATE_MORNING", "LUNCH_RUSH",
            "AFTERNOON", "EVENING_RUSH", "NIGHT", "LATE_NIGHT", "UNKNOWN"
        }
        t_band = (
            timestamp_or_band.upper()
            if timestamp_or_band.upper() in KNOWN_BANDS
            else time_band(timestamp_or_band)
        )

        features_used = {
            "distance_band": distance_band(distance_km),
            "time_band": t_band,
            "load_band": load_band(avg_load),
            "order_density_band": order_density_band(num_orders, num_riders),
        }

        label, confidence = self.model.predict(features_used)
        return {
            "label": label,
            "confidence": round(confidence, 4),
            "features_used": features_used,
            "model_trained": True,
        }

    # ------------------------------------------------------------------
    # Info / Diagnostics
    # ------------------------------------------------------------------

    def get_info(self) -> Dict[str, Any]:
        """Return a diagnostic snapshot of the analyzer's current state."""
        return {
            "is_trained": self.is_trained,
            "training_rows": self._training_size,
            "success_rate": round(self._success_rate, 4),
            "n_rules": len(self._rules),
            # Cap the rule dump so the payload stays small.
            "rules": self._rules[:20],
            "recent_trend": self._recent_trend,
            "feature_names": self.FEATURES,
            "feature_values": self.model.feature_values if self.is_trained else {},
            "classes": self.model.classes if self.is_trained else [],
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _fetch_rows(self, limit: int) -> List[Dict]:
        """Read the newest `limit` assignment-log rows as plain dicts.

        The connection is closed in a finally block so it is not leaked
        when the query raises (e.g. missing table or corrupt DB file);
        the original code skipped close() on any exception.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT * FROM assignment_ml_log ORDER BY id DESC LIMIT ?", (limit,)
            ).fetchall()
            return [dict(r) for r in rows]
        finally:
            conn.close()

    def _preprocess(self, rows: List[Dict]) -> Tuple[List[Dict], int]:
        """Convert raw DB rows into banded training dicts.

        Returns (training_data, number_of_success_rows). Rows that cannot
        be coerced are skipped rather than aborting the whole batch.
        """
        training_data: List[Dict] = []
        successes = 0
        for r in rows:
            try:
                # `or <default>` also maps present-but-NULL columns.
                dist_km = float(r.get("total_distance_km") or 0.0)
                ts = str(r.get("timestamp") or "")
                avg_ld = float(r.get("avg_load") or 0.0)
                n_orders = int(r.get("num_orders") or 0)
                n_riders = int(r.get("num_riders") or 1)
                unassigned = int(r.get("unassigned_count") or 0)

                label = "SUCCESS" if unassigned == 0 else "RISK"
                if label == "SUCCESS":
                    successes += 1

                training_data.append({
                    "distance_band": distance_band(dist_km),
                    "time_band": time_band(ts),
                    "load_band": load_band(avg_ld),
                    "order_density_band": order_density_band(n_orders, n_riders),
                    self.TARGET: label,
                })
            except Exception:
                # Best-effort: a malformed row must not kill training.
                continue
        return training_data, successes

    def _compute_trend(self, rows: List[Dict], window: int = 50) -> None:
        """Compute rolling success rate over consecutive `window`-sized chunks.

        Uses the same missing-value convention as _preprocess (a missing or
        NULL unassigned_count counts as 0, i.e. SUCCESS) so the trend agrees
        with the training labels. The original used `r.get(..., 1)`, which
        both disagreed with _preprocess and raised TypeError on NULL values.
        """
        trend = []
        for i in range(0, len(rows), window):
            chunk = rows[i:i + window]
            if not chunk:
                break
            rate = sum(1 for r in chunk if int(r.get("unassigned_count") or 0) == 0) / len(chunk)
            trend.append(round(rate, 4))
        # Keep only the 20 most recent windows.
        self._recent_trend = trend[-20:]

    def _save_tree(self) -> None:
        """Persist the trained tree as JSON; failure is logged, not raised."""
        try:
            os.makedirs(os.path.dirname(self._tree_path) or ".", exist_ok=True)
            with open(self._tree_path, "w") as f:
                f.write(self.model.to_json())
            logger.info(f"ID3 tree persisted -> {self._tree_path}")
        except Exception as e:
            # Persistence is an optimization; the in-memory tree still works.
            logger.warning(f"ID3 tree save failed: {e}")

    def _load_tree(self) -> None:
        """Restore a persisted tree from disk, if present and readable."""
        try:
            if not os.path.exists(self._tree_path):
                return
            with open(self._tree_path) as f:
                self.model = ID3Classifier.from_json(f.read())
            self.is_trained = True
            self._rules = self.model.get_tree_rules()
            logger.info(f"ID3 tree restored - rules={len(self._rules)}")
        except Exception as e:
            # A corrupt file is recoverable: stay untrained and retrain later.
            logger.warning(f"ID3 tree load failed (will retrain): {e}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Process-wide singleton and the lock that guards its lazy construction.
_analyzer: Optional[BehaviorAnalyzer] = None
_analyzer_lock = threading.Lock()


def get_analyzer() -> BehaviorAnalyzer:
    """Return the shared BehaviorAnalyzer, creating and training it lazily."""
    global _analyzer
    with _analyzer_lock:
        analyzer = _analyzer
        if analyzer is None:
            analyzer = BehaviorAnalyzer()
            _analyzer = analyzer
        if not analyzer.is_trained:
            # No persisted tree was restored: rebuild from history now.
            analyzer.train_on_history()
        return analyzer
|
||||
Reference in New Issue
Block a user