"""
ML Admin API - rider-api

Endpoints:
    GET  /api/v1/ml/status  - DB record count, quality trend, model info
    GET  /api/v1/ml/config  - Current active hyperparameters (ML-tuned + defaults)
    POST /api/v1/ml/train   - Trigger hypertuning immediately
    POST /api/v1/ml/reset   - Reset config to factory defaults
    GET  /api/v1/ml/reports - List past tuning reports
"""
|
|
|
|
import logging
|
|
import os
|
|
import json
|
|
from fastapi import APIRouter, HTTPException, Body, Request
|
|
from fastapi.responses import FileResponse, PlainTextResponse
|
|
from typing import Optional
|
|
|
|
logger = logging.getLogger(__name__)

# JSON admin API for the ML hypertuner, mounted under /api/v1/ml.
router = APIRouter(
    prefix="/api/v1/ml",
    tags=["ML Hypertuner"],
    responses={500: {"description": "Internal server error"}},
)

# Unprefixed router that serves the human-facing monitoring dashboard.
web_router = APIRouter(tags=["ML Monitor Web Dashboard"])
|
# -----------------------------------------------------------------------------
# GET /ml-ops
# -----------------------------------------------------------------------------

@web_router.get("/ml-ops", summary="Visual ML monitoring dashboard")
def ml_dashboard():
    """Serve the HTML dashboard used to visualize ML progress."""
    # Resolved against the process CWD — assumes the app is launched from the
    # project root. NOTE(review): confirm for other deployment layouts.
    template = os.path.join(os.getcwd(), "app/templates/ml_dashboard.html")
    if os.path.isfile(template):
        return FileResponse(template)
    raise HTTPException(status_code=404, detail=f"Dashboard template not found at {template}")
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# GET /status
# -----------------------------------------------------------------------------

@router.get("/status", summary="ML system status & quality trend")
def ml_status():
    """
    Report overall ML system health.

    Returns:
        - db_records: how many assignment events are logged
        - quality_trend: recent quality score stats (last 50 records)
        - model: surrogate model info, including whether it has been trained
        - config: current hyperparameter values (ML-tuned or defaults)
        - message: human-readable next step for the operator

    Raises:
        HTTPException(500): if any collector/tuner call fails.
    """
    from app.services.ml.ml_data_collector import get_collector
    from app.services.ml.ml_hypertuner import get_hypertuner

    # Minimum logged records before training is considered viable.
    # Keep in sync with the default `min_records` of POST /train.
    min_train_records = 30

    try:
        collector = get_collector()
        tuner = get_hypertuner()

        record_count = collector.count_records()
        quality_trend = collector.get_recent_quality_trend(last_n=50)
        model_info = tuner.get_model_info()

        from app.services.ml.behavior_analyzer import get_analyzer
        b_analyzer = get_analyzer()

        from app.config.dynamic_config import get_config
        cfg = get_config()

        # Flattened from a nested conditional expression for readability.
        if record_count < min_train_records:
            message = (
                f"Collecting data - need {min_train_records - record_count} "
                "more records to train."
            )
        elif not model_info["model_trained"]:
            message = "Ready to train! Call POST /api/v1/ml/train"
        else:
            message = "Model trained and active."

        return {
            "status": "ok",
            "db_records": record_count,
            "ready_to_train": record_count >= min_train_records,
            "quality_trend": quality_trend,
            "hourly_stats": collector.get_hourly_stats(),
            "quality_histogram": collector.get_quality_histogram(),
            "strategy_comparison": collector.get_strategy_comparison(),
            "zone_stats": collector.get_zone_stats(),
            # Defensive: older analyzer implementations may lack get_info().
            "behavior": b_analyzer.get_info() if hasattr(b_analyzer, 'get_info') else {},
            "config": cfg.get_all(),
            "model": model_info,
            "message": message,
        }
    except Exception as e:
        logger.error(f"[ML API] Status failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# GET /config
# -----------------------------------------------------------------------------

@router.get("/config", summary="Current active hyperparameter values")
def ml_config():
    """
    List every hyperparameter currently in use by the system.

    Each entry is annotated with its origin: 'ml_tuned' means the value was
    written by the ML model, 'default' means the factory default is active.
    """
    from app.config.dynamic_config import get_config, DEFAULTS

    try:
        cfg = get_config()
        # NOTE(review): peeks at the private cache to tell tuned keys apart
        # from defaults — confirm there is no public API for this.
        tuned_keys = set(cfg._cache.keys())

        hyperparameters = {
            key: {
                "value": value,
                "source": "ml_tuned" if key in tuned_keys else "default",
            }
            for key, value in cfg.get_all().items()
        }

        tuned_total = sum(
            1 for entry in hyperparameters.values() if entry["source"] == "ml_tuned"
        )
        return {
            "status": "ok",
            "hyperparameters": hyperparameters,
            "total_params": len(hyperparameters),
            "ml_tuned_count": tuned_total,
        }
    except Exception as e:
        logger.error(f"[ML API] Config fetch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.patch("/config", summary="Update specific ML configuration defaults")
def ml_config_patch(payload: dict = Body(...)):
    """
    Apply JSON overrides to any active parameter.

    Example body: {"ml_strategy": "balanced"}

    Raises:
        HTTPException(500): if writing the overrides fails.
    """
    from app.config.dynamic_config import get_config
    try:
        cfg = get_config()
        # Tag the writes so their origin is distinguishable from ML-tuned values.
        cfg.set_bulk(payload, source="ml_admin")
        return {"status": "ok"}
    except Exception as e:
        logger.error(f"[ML API] Config patch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# POST /train
# -----------------------------------------------------------------------------

@router.post("/train", summary="Trigger XGBoost training + Optuna hyperparameter search")
def ml_train(
    # Range fixed to match the ge/le validators ("(10500)" was a garbled "10-500").
    n_trials: int = Body(default=100, embed=True, ge=10, le=500,
                         description="Number of Optuna trials (10-500)"),
    min_records: int = Body(default=30, embed=True, ge=10,
                            description="Minimum DB records required")
):
    """
    Run the full hypertuning pipeline:

    1. Load logged assignment data from DB
    2. Train XGBoost surrogate model
    3. Run Optuna TPE search (`n_trials` trials)
    4. Write optimal params to DynamicConfig

    The AssignmentService picks up new params within 5 minutes (auto-reload).

    Raises:
        HTTPException(500): if any stage of the pipeline fails.
    """
    from app.services.ml.ml_hypertuner import get_hypertuner

    try:
        logger.info(f"[ML API] Hypertuning triggered: n_trials={n_trials}, min_records={min_records}")
        tuner = get_hypertuner()
        result = tuner.run(n_trials=n_trials, min_training_records=min_records)
        return result
    except Exception as e:
        logger.error(f"[ML API] Training failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# POST /reset
# -----------------------------------------------------------------------------

@router.post("/reset", summary="Reset all hyperparameters to factory defaults")
def ml_reset():
    """
    Discard every ML-tuned config value and restore the original hardcoded
    defaults. Useful escape hatch when the model produced bad results.
    """
    from app.config.dynamic_config import get_config

    try:
        config = get_config()
        config.reset_to_defaults()
        return {
            "status": "ok",
            "message": "All hyperparameters reset to factory defaults.",
        }
    except Exception as e:
        logger.error(f"[ML API] Reset failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# -----------------------------------------------------------------------------
# POST /strategy
# -----------------------------------------------------------------------------

@router.post("/strategy", summary="Change the AI Optimization Prompt/Strategy")
def ml_strategy(strategy: str = Body(default="balanced", embed=True)):
    """
    Change the mathematical objective of the AI.

    Choices: 'balanced', 'fuel_saver', 'aggressive_speed', 'zone_strict'

    Historical data is NOT wiped. Instead, the AI dynamically recalculates
    the quality score of all past events using the new strategy rules.

    Raises:
        HTTPException(400): if `strategy` is not one of the known choices.
        HTTPException(500): if persisting the new strategy fails.
    """
    # Removed unused `import sqlite3` — nothing in this handler touches the DB
    # directly; persistence goes through DynamicConfig.
    from app.config.dynamic_config import get_config

    valid = ["balanced", "fuel_saver", "aggressive_speed", "zone_strict"]
    if strategy not in valid:
        # Reject unknown strategies before touching the config store.
        raise HTTPException(status_code=400, detail=f"Invalid strategy. Choose from {valid}")

    try:
        get_config().set("ml_strategy", strategy)

        return {
            "status": "ok",
            "message": f"Strategy changed to '{strategy}'. Historical AI data will be mathematically repurposed to train towards this new goal.",
            "strategy": strategy
        }
    except Exception as e:
        logger.error(f"[ML API] Strategy change failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# GET /reports
# -----------------------------------------------------------------------------

@router.get("/reports", summary="List past hypertuning reports")
def ml_reports():
    """
    Return the 10 most recent tuning reports (JSON files in ml_data/reports/).

    Raises:
        HTTPException(500): if listing the report directory fails.
    """
    try:
        report_dir = "ml_data/reports"
        if not os.path.isdir(report_dir):
            return {"status": "ok", "reports": [], "message": "No reports yet."}

        # Reverse lexicographic sort puts newest first — presumably filenames
        # embed timestamps; TODO confirm the report naming scheme.
        files = sorted(
            (f for f in os.listdir(report_dir) if f.endswith(".json")),
            reverse=True,
        )[:10]

        reports = []
        for fname in files:
            path = os.path.join(report_dir, fname)
            try:
                with open(path) as f:
                    reports.append(json.load(f))
            except Exception as parse_err:
                # Still best-effort (a corrupt file must not break the listing),
                # but no longer a silent swallow — leave a trace for operators.
                logger.warning(f"[ML API] Skipping unreadable report {path}: {parse_err}")

        return {"status": "ok", "reports": reports, "count": len(reports)}
    except Exception as e:
        logger.error(f"[ML API] Reports fetch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# GET /export
# -----------------------------------------------------------------------------

@router.get("/export", summary="Export all records as CSV")
def ml_export():
    """Return every row of the assignment_ml_log table as a CSV download."""
    try:
        from app.services.ml.ml_data_collector import get_collector

        csv_payload = get_collector().export_csv()
        # Content-Disposition forces browsers to download rather than render.
        return PlainTextResponse(
            content=csv_payload,
            media_type="text/csv",
            headers={"Content-Disposition": 'attachment; filename="ml_export.csv"'},
        )
    except Exception as e:
        logger.error(f"[ML API] Export failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|