"""
ML Data Collector - Production Grade
======================================

Logs every assignment call (inputs + outcomes) to SQLite.

Key upgrades over the original
--------------------------------
1. FROZEN historical scores - quality_score is written ONCE at log time.
   get_training_data() returns scores as-is from the DB (no retroactive mutation).
2. Rich schema - zone_id, city_id, is_peak, weather_code,
   sla_breached, avg_delivery_time_min for richer features.
3. SLA tracking - logs whether delivery SLA was breached.
4. Analytics API - get_hourly_stats(), get_strategy_comparison(),
   get_quality_histogram(), get_zone_stats() for dashboard consumption.
5. Thread-safe writes - connection-per-write pattern for FastAPI workers.
6. Indexed columns - timestamp, ml_strategy, zone_id for fast queries.
"""
|
|
|
|
import csv
|
|
import io
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import threading
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)

# SQLite file location; overridable via ML_DB_PATH for tests and deployments.
_DB_PATH = os.getenv("ML_DB_PATH", "ml_data/ml_store.db")

# Serializes INSERTs across worker threads (SQLite allows a single writer).
_WRITE_LOCK = threading.Lock()
|
|
|
|
|
|
def _std(values: List[float]) -> float:
|
|
if len(values) < 2:
|
|
return 0.0
|
|
mean = sum(values) / len(values)
|
|
return (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
|
|
|
|
|
|
class MLDataCollector:
    """
    Event logger for assignment service calls, backed by SQLite.

    Each log_assignment_event() call writes one row capturing:
      - Operating context (time, orders, riders, zone, city, weather)
      - Active hyperparams (exact config snapshot for this call)
      - Measured outcomes (quality score, SLA, latency, distances)

    quality_score is computed once at log time and FROZEN - it is never
    retroactively recomputed, so historical training data stays stable.

    Writes are serialized through the module-level _WRITE_LOCK and use a
    connection-per-operation pattern, so a single instance can be shared
    across FastAPI worker threads.
    """

    def __init__(self):
        self._db_path = _DB_PATH
        # Create/migrate the schema eagerly so the first log call never
        # races table creation.
        self._ensure_db()

    # ------------------------------------------------------------------
    # Main logging API
    # ------------------------------------------------------------------

    def log_assignment_event(
        self,
        *,
        num_orders: int,
        num_riders: int,
        hyperparams: Dict[str, Any],
        assignments: Dict[int, List[Any]],
        unassigned_count: int,
        elapsed_ms: float,
        zone_id: str = "default",
        city_id: str = "default",
        weather_code: str = "CLEAR",
        sla_minutes: Optional[float] = None,
        avg_delivery_time_min: Optional[float] = None,
    ) -> None:
        """
        Log one assignment event.

        Call this at the END of AssignmentService.assign_orders() once
        outcomes are known. Never raises: any failure is downgraded to a
        warning so telemetry cannot break the assignment path.
        """
        try:
            # NOTE: naive UTC is kept deliberately - stored isoformat strings
            # must stay lexicographically comparable with existing rows.
            now = datetime.utcnow()
            hour = now.hour
            day_of_week = now.weekday()
            is_peak = int(hour in (7, 8, 9, 12, 13, 18, 19, 20))

            # Per-rider load stats; riders with zero orders are excluded.
            rider_loads = [len(orders) for orders in assignments.values() if orders]
            riders_used = len(rider_loads)
            total_assigned = sum(rider_loads)
            avg_load = total_assigned / riders_used if riders_used else 0.0
            load_std = _std(rider_loads) if rider_loads else 0.0

            all_orders = [o for orders in assignments.values() if orders for o in orders]
            total_distance_km = sum(self._get_km(o) for o in all_orders)
            ml_strategy = hyperparams.get("ml_strategy", "balanced")
            max_opr = hyperparams.get("max_orders_per_rider", 12)

            # BUGFIX: compare against None explicitly - a 0-minute SLA or a
            # 0.0 delivery time was previously treated as "no data" and the
            # breach flag was silently skipped.
            sla_breached = 0
            if sla_minutes is not None and avg_delivery_time_min is not None:
                sla_breached = int(avg_delivery_time_min > sla_minutes)

            # Quality score - FROZEN at log time (never recomputed later).
            quality_score = self._compute_quality_score(
                num_orders=num_orders,
                unassigned_count=unassigned_count,
                load_std=load_std,
                riders_used=riders_used,
                num_riders=num_riders,
                total_distance_km=total_distance_km,
                max_orders_per_rider=max_opr,
                ml_strategy=ml_strategy,
            )

            row = {
                "timestamp": now.isoformat(),
                "hour": hour,
                "day_of_week": day_of_week,
                "is_peak": is_peak,
                "zone_id": zone_id,
                "city_id": city_id,
                "weather_code": weather_code,
                "num_orders": num_orders,
                "num_riders": num_riders,
                "max_pickup_distance_km": hyperparams.get("max_pickup_distance_km", 10.0),
                "max_kitchen_distance_km": hyperparams.get("max_kitchen_distance_km", 3.0),
                "max_orders_per_rider": max_opr,
                "ideal_load": hyperparams.get("ideal_load", 6),
                "workload_balance_threshold": hyperparams.get("workload_balance_threshold", 0.7),
                "workload_penalty_weight": hyperparams.get("workload_penalty_weight", 100.0),
                "distance_penalty_weight": hyperparams.get("distance_penalty_weight", 2.0),
                "cluster_radius_km": hyperparams.get("cluster_radius_km", 3.0),
                "search_time_limit_seconds": hyperparams.get("search_time_limit_seconds", 5),
                "road_factor": hyperparams.get("road_factor", 1.3),
                "ml_strategy": ml_strategy,
                "riders_used": riders_used,
                "total_assigned": total_assigned,
                "unassigned_count": unassigned_count,
                "avg_load": round(avg_load, 3),
                "load_std": round(load_std, 3),
                "total_distance_km": round(total_distance_km, 2),
                "elapsed_ms": round(elapsed_ms, 1),
                "sla_breached": sla_breached,
                "avg_delivery_time_min": round(avg_delivery_time_min or 0.0, 2),
                "quality_score": round(quality_score, 2),
            }

            with _WRITE_LOCK:
                self._insert(row)

            logger.info(
                f"[MLCollector] zone={zone_id} orders={num_orders} "
                f"assigned={total_assigned} unassigned={unassigned_count} "
                f"quality={quality_score:.1f} elapsed={elapsed_ms:.0f}ms"
            )

        except Exception as e:
            logger.warning(f"[MLCollector] Logging failed (non-fatal): {e}")

    # ------------------------------------------------------------------
    # Data retrieval for training
    # ------------------------------------------------------------------

    def get_training_data(
        self,
        min_records: int = 30,
        strategy_filter: Optional[str] = None,
        since_hours: Optional[int] = None,
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Return logged rows for model training, oldest first.

        quality_score is returned AS-IS (frozen at log time - no re-scoring).
        Returns None when fewer than min_records rows match, or on error.
        """
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                conn.row_factory = sqlite3.Row

                query = "SELECT * FROM assignment_ml_log"
                params: list = []
                clauses: list = []

                if strategy_filter:
                    clauses.append("ml_strategy = ?")
                    params.append(strategy_filter)
                if since_hours:  # None and 0 both mean "no time filter"
                    cutoff = (datetime.utcnow() - timedelta(hours=since_hours)).isoformat()
                    clauses.append("timestamp >= ?")
                    params.append(cutoff)

                if clauses:
                    query += " WHERE " + " AND ".join(clauses)
                query += " ORDER BY id ASC"

                rows = conn.execute(query, params).fetchall()
            finally:
                # BUGFIX: close even when the query raises (no connection leak).
                conn.close()

            if len(rows) < min_records:
                logger.info(f"[MLCollector] {len(rows)} records < {min_records} minimum.")
                return None

            return [dict(r) for r in rows]

        except Exception as e:
            logger.error(f"[MLCollector] get_training_data failed: {e}")
            return None

    # ------------------------------------------------------------------
    # Analytics API
    # ------------------------------------------------------------------

    def get_recent_quality_trend(self, last_n: int = 50) -> Dict[str, Any]:
        """Recent quality scores + series for sparkline charts."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                rows = conn.execute(
                    "SELECT quality_score, timestamp, unassigned_count, elapsed_ms "
                    "FROM assignment_ml_log ORDER BY id DESC LIMIT ?", (last_n,)
                ).fetchall()
            finally:
                conn.close()
            if not rows:
                return {"avg_quality": 0.0, "sample_size": 0, "history": []}
            scores = [r[0] for r in rows]
            # Rows arrive newest-first; series are reversed to read oldest->newest.
            return {
                "avg_quality": round(sum(scores) / len(scores), 2),
                "min_quality": round(min(scores), 2),
                "max_quality": round(max(scores), 2),
                "sample_size": len(scores),
                "history": list(reversed(scores)),
                "timestamps": list(reversed([r[1] for r in rows])),
                "unassigned_series": list(reversed([r[2] for r in rows])),
                "latency_series": list(reversed([r[3] for r in rows])),
            }
        except Exception:
            return {"avg_quality": 0.0, "sample_size": 0, "history": []}

    def get_hourly_stats(self, last_days: int = 7) -> List[Dict[str, Any]]:
        """Quality, SLA, and call volume aggregated by hour-of-day."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                cutoff = (datetime.utcnow() - timedelta(days=last_days)).isoformat()
                rows = conn.execute(
                    """
                    SELECT hour,
                           COUNT(*) AS call_count,
                           AVG(quality_score) AS avg_quality,
                           AVG(unassigned_count) AS avg_unassigned,
                           AVG(elapsed_ms) AS avg_latency_ms,
                           SUM(CASE WHEN sla_breached=1 THEN 1 ELSE 0 END) AS sla_breaches
                    FROM assignment_ml_log WHERE timestamp >= ?
                    GROUP BY hour ORDER BY hour
                    """, (cutoff,)
                ).fetchall()
            finally:
                conn.close()
            # AVG() is NULL on empty groups; coerce to 0.0 for the dashboard.
            return [
                {
                    "hour": r[0],
                    "call_count": r[1],
                    "avg_quality": round(r[2] or 0.0, 2),
                    "avg_unassigned": round(r[3] or 0.0, 2),
                    "avg_latency_ms": round(r[4] or 0.0, 1),
                    "sla_breaches": r[5],
                }
                for r in rows
            ]
        except Exception as e:
            logger.error(f"[MLCollector] get_hourly_stats: {e}")
            return []

    def get_strategy_comparison(self) -> List[Dict[str, Any]]:
        """Compare quality metrics across ml_strategy values."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                rows = conn.execute(
                    """
                    SELECT ml_strategy,
                           COUNT(*) AS call_count,
                           AVG(quality_score) AS avg_quality,
                           MIN(quality_score) AS min_quality,
                           MAX(quality_score) AS max_quality,
                           AVG(unassigned_count) AS avg_unassigned,
                           AVG(total_distance_km) AS avg_distance_km,
                           AVG(elapsed_ms) AS avg_latency_ms
                    FROM assignment_ml_log
                    GROUP BY ml_strategy ORDER BY avg_quality DESC
                    """
                ).fetchall()
            finally:
                conn.close()
            return [
                {
                    "strategy": r[0],
                    "call_count": r[1],
                    "avg_quality": round(r[2] or 0.0, 2),
                    "min_quality": round(r[3] or 0.0, 2),
                    "max_quality": round(r[4] or 0.0, 2),
                    "avg_unassigned": round(r[5] or 0.0, 2),
                    "avg_distance_km": round(r[6] or 0.0, 2),
                    "avg_latency_ms": round(r[7] or 0.0, 1),
                }
                for r in rows
            ]
        except Exception as e:
            logger.error(f"[MLCollector] get_strategy_comparison: {e}")
            return []

    def get_quality_histogram(self, bins: int = 10) -> List[Dict[str, Any]]:
        """Quality score distribution for histogram chart (scores are 0..100)."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                rows = conn.execute("SELECT quality_score FROM assignment_ml_log").fetchall()
            finally:
                conn.close()
            scores = [r[0] for r in rows if r[0] is not None]
            if not scores:
                return []
            bin_width = 100.0 / bins
            # BUGFIX: the top bin is closed on the right so a perfect 100.0
            # score is counted instead of falling outside every bin.
            return [
                {
                    "range": f"{i*bin_width:.0f}-{(i+1)*bin_width:.0f}",
                    "count": sum(
                        1 for s in scores
                        if i * bin_width <= s < (i + 1) * bin_width
                        or (i == bins - 1 and s >= 100.0)
                    ),
                }
                for i in range(bins)
            ]
        except Exception as e:
            logger.error(f"[MLCollector] get_quality_histogram: {e}")
            return []

    def get_zone_stats(self) -> List[Dict[str, Any]]:
        """Quality and SLA stats grouped by zone."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                rows = conn.execute(
                    """
                    SELECT zone_id, COUNT(*) AS call_count,
                           AVG(quality_score) AS avg_quality,
                           SUM(sla_breached) AS sla_breaches,
                           AVG(total_distance_km) AS avg_distance_km
                    FROM assignment_ml_log
                    GROUP BY zone_id ORDER BY avg_quality DESC
                    """
                ).fetchall()
            finally:
                conn.close()
            return [
                {
                    "zone_id": r[0],
                    "call_count": r[1],
                    "avg_quality": round(r[2] or 0.0, 2),
                    "sla_breaches": r[3],
                    "avg_distance_km": round(r[4] or 0.0, 2),
                }
                for r in rows
            ]
        except Exception as e:
            logger.error(f"[MLCollector] get_zone_stats: {e}")
            return []

    def count_records(self) -> int:
        """Total number of logged rows; 0 on any error."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                return conn.execute("SELECT COUNT(*) FROM assignment_ml_log").fetchone()[0]
            finally:
                conn.close()
        except Exception:
            return 0

    def count_by_strategy(self) -> Dict[str, int]:
        """Row counts keyed by ml_strategy; empty dict on any error."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                rows = conn.execute(
                    "SELECT ml_strategy, COUNT(*) FROM assignment_ml_log GROUP BY ml_strategy"
                ).fetchall()
            finally:
                conn.close()
            return {r[0]: r[1] for r in rows}
        except Exception:
            return {}

    def export_csv(self) -> str:
        """Export all records as a CSV string ('' when empty or on error)."""
        try:
            conn = sqlite3.connect(self._db_path)
            try:
                conn.row_factory = sqlite3.Row
                rows = conn.execute("SELECT * FROM assignment_ml_log ORDER BY id ASC").fetchall()
            finally:
                conn.close()
            if not rows:
                return ""
            buf = io.StringIO()
            writer = csv.DictWriter(buf, fieldnames=rows[0].keys())
            writer.writeheader()
            writer.writerows([dict(r) for r in rows])
            return buf.getvalue()
        except Exception as e:
            logger.error(f"[MLCollector] export_csv failed: {e}")
            return ""

    def purge_old_records(self, keep_days: int = 90) -> int:
        """Delete records older than keep_days. Returns count deleted (0 on error)."""
        try:
            cutoff = (datetime.utcnow() - timedelta(days=keep_days)).isoformat()
            conn = sqlite3.connect(self._db_path)
            try:
                cursor = conn.execute(
                    "DELETE FROM assignment_ml_log WHERE timestamp < ?", (cutoff,)
                )
                deleted = cursor.rowcount
                conn.commit()
            finally:
                conn.close()
            logger.info(f"[MLCollector] Purged {deleted} records older than {keep_days} days.")
            return deleted
        except Exception as e:
            logger.error(f"[MLCollector] purge failed: {e}")
            return 0

    # ------------------------------------------------------------------
    # Quality Score Formula (frozen at log time - do not change behavior)
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_quality_score(
        num_orders: int, unassigned_count: int, load_std: float,
        riders_used: int, num_riders: int, total_distance_km: float,
        max_orders_per_rider: int, ml_strategy: str = "balanced",
    ) -> float:
        """
        Score one assignment run on a 0-100 scale.

        Blends three ratios - completion, distance efficiency, and load
        balance - with per-strategy weights. The formula is intentionally
        FROZEN: historical rows were scored with it and must stay comparable.
        """
        if num_orders == 0:
            return 0.0
        assigned_ratio = 1.0 - (unassigned_count / num_orders)
        # Normalize load spread against half the per-rider cap.
        max_std = max(1.0, max_orders_per_rider / 2.0)
        balance_ratio = max(0.0, 1.0 - (load_std / max_std))
        # Budget of 8 km per assigned order; max(1.0, ...) avoids divide-by-zero.
        max_dist = max(1.0, float((num_orders - unassigned_count) * 8.0))
        distance_ratio = max(0.0, 1.0 - (total_distance_km / max_dist))
        # (completion, distance, balance) weights; each row sums to 100.
        weights = {
            "aggressive_speed": (80.0, 20.0, 0.0),
            "fuel_saver": (30.0, 70.0, 0.0),
            "zone_strict": (40.0, 30.0, 30.0),
            "balanced": (50.0, 25.0, 25.0),
        }
        w_comp, w_dist, w_bal = weights.get(ml_strategy, (50.0, 25.0, 25.0))
        return min(
            assigned_ratio * w_comp + distance_ratio * w_dist + balance_ratio * w_bal,
            100.0,
        )

    @staticmethod
    def _get_km(order: Any) -> float:
        """Best-effort distance (km) from a dict-like order; 0.0 when absent/invalid."""
        try:
            return float(order.get("kms") or order.get("calculationDistanceKm") or 0.0)
        except Exception:
            # Non-dict orders or unparsable values contribute no distance.
            return 0.0

    # ------------------------------------------------------------------
    # DB Bootstrap
    # ------------------------------------------------------------------

    def _ensure_db(self) -> None:
        """Create the log table, apply additive column migrations, build indexes."""
        try:
            os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True)
            conn = sqlite3.connect(self._db_path)
            try:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS assignment_ml_log (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        hour INTEGER,
                        day_of_week INTEGER,
                        is_peak INTEGER DEFAULT 0,
                        zone_id TEXT DEFAULT 'default',
                        city_id TEXT DEFAULT 'default',
                        weather_code TEXT DEFAULT 'CLEAR',
                        num_orders INTEGER,
                        num_riders INTEGER,
                        max_pickup_distance_km REAL,
                        max_kitchen_distance_km REAL,
                        max_orders_per_rider INTEGER,
                        ideal_load INTEGER,
                        workload_balance_threshold REAL,
                        workload_penalty_weight REAL,
                        distance_penalty_weight REAL,
                        cluster_radius_km REAL,
                        search_time_limit_seconds INTEGER,
                        road_factor REAL,
                        ml_strategy TEXT DEFAULT 'balanced',
                        riders_used INTEGER,
                        total_assigned INTEGER,
                        unassigned_count INTEGER,
                        avg_load REAL,
                        load_std REAL,
                        total_distance_km REAL DEFAULT 0.0,
                        elapsed_ms REAL,
                        sla_breached INTEGER DEFAULT 0,
                        avg_delivery_time_min REAL DEFAULT 0.0,
                        quality_score REAL
                    )
                """)
                # Additive migrations for DBs created by older schema versions.
                migrations = [
                    "ALTER TABLE assignment_ml_log ADD COLUMN is_peak INTEGER DEFAULT 0",
                    "ALTER TABLE assignment_ml_log ADD COLUMN zone_id TEXT DEFAULT 'default'",
                    "ALTER TABLE assignment_ml_log ADD COLUMN city_id TEXT DEFAULT 'default'",
                    "ALTER TABLE assignment_ml_log ADD COLUMN weather_code TEXT DEFAULT 'CLEAR'",
                    "ALTER TABLE assignment_ml_log ADD COLUMN sla_breached INTEGER DEFAULT 0",
                    "ALTER TABLE assignment_ml_log ADD COLUMN avg_delivery_time_min REAL DEFAULT 0.0",
                    "ALTER TABLE assignment_ml_log ADD COLUMN ml_strategy TEXT DEFAULT 'balanced'",
                    "ALTER TABLE assignment_ml_log ADD COLUMN total_distance_km REAL DEFAULT 0.0",
                ]
                for ddl in migrations:
                    try:
                        conn.execute(ddl)
                    except sqlite3.OperationalError:
                        # Expected "duplicate column name" when already migrated.
                        pass
                for idx in [
                    "CREATE INDEX IF NOT EXISTS idx_timestamp ON assignment_ml_log(timestamp)",
                    "CREATE INDEX IF NOT EXISTS idx_strategy ON assignment_ml_log(ml_strategy)",
                    "CREATE INDEX IF NOT EXISTS idx_zone ON assignment_ml_log(zone_id)",
                ]:
                    conn.execute(idx)
                conn.commit()
            finally:
                conn.close()
        except Exception as e:
            logger.error(f"[MLCollector] DB init failed: {e}")

    def _insert(self, row: Dict[str, Any]) -> None:
        """
        Insert one row; caller must hold _WRITE_LOCK.

        Column names come from the internally-built row dict (never user
        input), so interpolating them into the SQL is safe.
        """
        os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True)
        conn = sqlite3.connect(self._db_path)
        try:
            cols = ", ".join(row.keys())
            placeholders = ", ".join(["?"] * len(row))
            conn.execute(
                f"INSERT INTO assignment_ml_log ({cols}) VALUES ({placeholders})",
                list(row.values()),
            )
            conn.commit()
        finally:
            conn.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level singleton
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Lazily-created process-wide collector instance.
_collector: Optional[MLDataCollector] = None


def get_collector() -> MLDataCollector:
    """Return the shared MLDataCollector, building it on first use."""
    global _collector
    if _collector is not None:
        return _collector
    _collector = MLDataCollector()
    return _collector
|