initial project setup with README and ignore
This commit is contained in:
286
app/routes/ml_admin.py
Normal file
286
app/routes/ml_admin.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
ML Admin API - rider-api
|
||||
|
||||
Endpoints:
|
||||
GET /api/v1/ml/status - DB record count, quality trend, model info
|
||||
GET /api/v1/ml/config - Current active hyperparameters (ML-tuned + defaults)
|
||||
POST /api/v1/ml/train - Trigger hypertuning immediately
|
||||
POST /api/v1/ml/reset - Reset config to factory defaults
|
||||
GET /api/v1/ml/reports - List past tuning reports
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from fastapi import APIRouter, HTTPException, Body, Request
|
||||
from fastapi.responses import FileResponse, PlainTextResponse
|
||||
from typing import Optional
|
||||
|
||||
# Module-level logger for all ML admin endpoints.
logger = logging.getLogger(__name__)

# JSON API router: every ML-ops endpoint below is mounted under /api/v1/ml.
router = APIRouter(
    prefix="/api/v1/ml",
    tags=["ML Hypertuner"],
    responses={
        500: {"description": "Internal server error"}
    }
)

# Separate, un-prefixed router for the human-facing HTML dashboard (/ml-ops).
web_router = APIRouter(
    tags=["ML Monitor Web Dashboard"]
)
|
||||
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GET /ml-ops
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@web_router.get("/ml-ops", summary="Visual ML monitoring dashboard")
def ml_dashboard():
    """Serve the HTML dashboard page used to visualize ML progress.

    Raises:
        HTTPException: 404 when the template file is missing on disk.
    """
    # NOTE(review): resolved relative to the process working directory —
    # assumes the app is always launched from the repo root; confirm.
    template_path = os.path.join(os.getcwd(), "app/templates/ml_dashboard.html")
    if os.path.isfile(template_path):
        return FileResponse(template_path)
    raise HTTPException(status_code=404, detail=f"Dashboard template not found at {template_path}")
|
||||
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GET /status
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.get("/status", summary="ML system status & quality trend")
def ml_status():
    """
    Report the overall health of the ML pipeline.

    Returns:
        - How many assignment events are logged
        - Recent quality score trend (avg / min / max over last 20 calls)
        - Whether the model has been trained
        - Current hyperparameter source (ml_tuned vs defaults)

    Raises:
        HTTPException: 500 if any backing service fails.
    """
    from app.services.ml.ml_data_collector import get_collector
    from app.services.ml.ml_hypertuner import get_hypertuner

    try:
        data_collector = get_collector()
        hypertuner = get_hypertuner()

        total_records = data_collector.count_records()
        trend = data_collector.get_recent_quality_trend(last_n=50)
        model_details = hypertuner.get_model_info()

        from app.services.ml.behavior_analyzer import get_analyzer
        analyzer = get_analyzer()

        from app.config.dynamic_config import get_config
        active_config = get_config()

        # Human-readable summary of where the pipeline currently stands:
        # collecting -> ready to train -> trained.
        if total_records < 30:
            status_message = f"Collecting data - need {max(0, 30 - total_records)} more records to train."
        elif not model_details["model_trained"]:
            status_message = "Ready to train! Call POST /api/v1/ml/train"
        else:
            status_message = "Model trained and active."

        # Best-effort: not every analyzer implementation exposes get_info().
        behavior_info = analyzer.get_info() if hasattr(analyzer, 'get_info') else {}

        return {
            "status": "ok",
            "db_records": total_records,
            "ready_to_train": total_records >= 30,
            "quality_trend": trend,
            "hourly_stats": data_collector.get_hourly_stats(),
            "quality_histogram": data_collector.get_quality_histogram(),
            "strategy_comparison": data_collector.get_strategy_comparison(),
            "zone_stats": data_collector.get_zone_stats(),
            "behavior": behavior_info,
            "config": active_config.get_all(),
            "model": model_details,
            "message": status_message,
        }
    except Exception as e:
        logger.error(f"[ML API] Status failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GET /config
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.get("/config", summary="Current active hyperparameter values")
def ml_config():
    """
    Return every hyperparameter currently in use by the system.

    Values marked 'ml_tuned' were set by the ML model.
    Values marked 'default' are factory defaults (not yet tuned).

    Raises:
        HTTPException: 500 if the config service fails.
    """
    # Fix: the previous version also imported DEFAULTS, which was never used.
    from app.config.dynamic_config import get_config

    try:
        cfg = get_config()
        # NOTE(review): reads the config object's private _cache to tell
        # ml-tuned overrides apart from defaults — confirm no public API exists.
        tuned_keys = set(cfg._cache.keys())

        annotated = {
            name: {
                "value": value,
                "source": "ml_tuned" if name in tuned_keys else "default",
            }
            for name, value in cfg.get_all().items()
        }

        return {
            "status": "ok",
            "hyperparameters": annotated,
            "total_params": len(annotated),
            "ml_tuned_count": sum(1 for x in annotated.values() if x["source"] == "ml_tuned"),
        }
    except Exception as e:
        logger.error(f"[ML API] Config fetch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.patch("/config", summary="Update specific ML configuration defaults")
def ml_config_patch(payload: dict = Body(...)):
    """
    Apply JSON overrides to any active parameter.

    Example body: {"ml_strategy": "balanced"}

    Raises:
        HTTPException: 500 if the bulk update fails.
    """
    from app.config.dynamic_config import get_config

    try:
        # All overrides are tagged with the "ml_admin" source for auditability.
        get_config().set_bulk(payload, source="ml_admin")
    except Exception as e:
        logger.error(f"[ML API] Config patch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "ok"}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# POST /train
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.post("/train", summary="Trigger XGBoost training + Optuna hyperparameter search")
def ml_train(
    n_trials: int = Body(default=100, embed=True, ge=10, le=500,
                         # Fix: description previously read "(10500)" — a
                         # mis-encoded range; the validators are ge=10, le=500.
                         description="Number of Optuna trials (10-500)"),
    min_records: int = Body(default=30, embed=True, ge=10,
                            description="Minimum DB records required")
):
    """
    Run the full hypertuning pipeline:

    1. Load logged assignment data from DB
    2. Train XGBoost surrogate model
    3. Run Optuna TPE search (n_trials trials)
    4. Write optimal params to DynamicConfig

    The AssignmentService picks up new params within 5 minutes (auto-reload).

    Args:
        n_trials: Number of Optuna trials to run (validated to 10-500).
        min_records: Minimum number of logged records required before training.

    Raises:
        HTTPException: 500 if the tuning pipeline fails for any reason.
    """
    from app.services.ml.ml_hypertuner import get_hypertuner

    try:
        logger.info(f"[ML API] Hypertuning triggered: n_trials={n_trials}, min_records={min_records}")
        tuner = get_hypertuner()
        # run() performs load -> train -> search -> persist and returns a report.
        result = tuner.run(n_trials=n_trials, min_training_records=min_records)
        return result
    except Exception as e:
        logger.error(f"[ML API] Training failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# POST /reset
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.post("/reset", summary="Reset all hyperparameters to factory defaults")
def ml_reset():
    """
    Wipe all ML-tuned config values and revert every parameter to the
    original hardcoded defaults. Useful if the model produced bad results.

    Raises:
        HTTPException: 500 if the reset fails.
    """
    from app.config.dynamic_config import get_config

    try:
        get_config().reset_to_defaults()
    except Exception as e:
        logger.error(f"[ML API] Reset failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "status": "ok",
        "message": "All hyperparameters reset to factory defaults.",
    }
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# POST /strategy
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.post("/strategy", summary="Change the AI Optimization Prompt/Strategy")
def ml_strategy(strategy: str = Body(default="balanced", embed=True)):
    """
    Change the mathematical objective of the AI.

    Choices: 'balanced', 'fuel_saver', 'aggressive_speed', 'zone_strict'

    Historical data is NOT wiped. Instead, the AI dynamically recalculates
    the quality score of all past events using the new strategy rules.

    Raises:
        HTTPException: 400 for an unknown strategy name, 500 on config failure.
    """
    # Fix: removed unused local `import sqlite3` (nothing in this handler
    # touches the database directly).
    from app.config.dynamic_config import get_config

    valid = ["balanced", "fuel_saver", "aggressive_speed", "zone_strict"]
    if strategy not in valid:
        # Fix: keyword arguments for consistency with every other handler here.
        raise HTTPException(status_code=400, detail=f"Invalid strategy. Choose from {valid}")

    try:
        get_config().set("ml_strategy", strategy)

        return {
            "status": "ok",
            "message": f"Strategy changed to '{strategy}'. Historical AI data will be mathematically repurposed to train towards this new goal.",
            "strategy": strategy
        }
    except Exception as e:
        logger.error(f"[ML API] Strategy change failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GET /reports
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.get("/reports", summary="List past hypertuning reports")
def ml_reports():
    """
    Return the last 10 tuning reports (JSON files in ml_data/reports/).

    Files are sorted by filename descending, so timestamped names yield
    newest-first ordering. Corrupt or unreadable files are skipped with a
    warning rather than failing the whole listing.

    Raises:
        HTTPException: 500 on unexpected listing failures.
    """
    try:
        report_dir = "ml_data/reports"
        if not os.path.isdir(report_dir):
            return {"status": "ok", "reports": [], "message": "No reports yet."}

        files = sorted(
            [f for f in os.listdir(report_dir) if f.endswith(".json")],
            reverse=True
        )[:10]

        reports = []
        for fname in files:
            path = os.path.join(report_dir, fname)
            try:
                with open(path) as f:
                    reports.append(json.load(f))
            # Fix: previously a bare `except Exception: pass` silently hid
            # corrupt files. Still best-effort, but narrow and logged.
            except (OSError, json.JSONDecodeError) as read_err:
                logger.warning(f"[ML API] Skipping unreadable report {path}: {read_err}")

        return {"status": "ok", "reports": reports, "count": len(reports)}
    except Exception as e:
        logger.error(f"[ML API] Reports fetch failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GET /export
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@router.get("/export", summary="Export all records as CSV")
def ml_export():
    """Return every row of the assignment_ml_log table as a CSV download.

    Raises:
        HTTPException: 500 if CSV generation fails.
    """
    try:
        from app.services.ml.ml_data_collector import get_collector

        csv_payload = get_collector().export_csv()
        resp = PlainTextResponse(content=csv_payload, media_type="text/csv")
        # Instruct browsers to download the file instead of rendering inline.
        resp.headers["Content-Disposition"] = 'attachment; filename="ml_export.csv"'
        return resp
    except Exception as e:
        logger.error(f"[ML API] Export failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user