# Source listing metadata: 549 lines, 17 KiB, Python.
from fastapi import FastAPI
|
|
from pydantic import BaseModel
|
|
import json
|
|
import requests
|
|
import re
|
|
import traceback
|
|
import duckdb
|
|
import math
|
|
import os
|
|
|
|
from typing import Optional, Any, TypedDict, List, Dict, cast
|
|
from datetime import datetime, date
|
|
from decimal import Decimal
|
|
import time
|
|
|
|
# ---------------------------------------------------
|
|
# SCHEMAS (Pydantic for strict enforcement)
|
|
# ---------------------------------------------------
|
|
|
|
class ChartConfig(BaseModel):
    """Pydantic schema for one chart: its type, title, optional axes, and data."""

    type: str  # e.g., "bar", "pie", "line", "area", "table", "kpi"
    title: str
    x_axis: Optional[str] = None  # optional; defaults to None when omitted
    y_axis: Optional[str] = None  # optional; defaults to None when omitted
    data: List[Dict[str, Any]]  # required list of string-keyed dicts
|
|
|
|
class AnalyticsResponse(BaseModel):
    """Pydantic schema pairing an analysis text with a list of chart configs.

    NOTE(review): no code visible in this file references this model (the
    /chat endpoint returns a plain dict) -- confirm whether it is still needed.
    """

    analysis: str
    charts: List[ChartConfig]
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from langgraph.graph import StateGraph, END
|
|
|
|
# ---------------------------------------------------
|
|
# CONFIG
|
|
# ---------------------------------------------------
|
|
|
|
# Ollama chat endpoint that call_ollama() posts to.
OLLAMA_URL = "http://localhost:11434/api/chat"
# Model name sent in the "model" field of every Ollama request.
MODEL = "qwen2.5:3b"
# DuckDB should connect to a local persistent file
DUCKDB_PATH = "analytics.duckdb"
# File that log_debug() appends its timestamped messages to.
DEBUG_LOG = "debug.log"
|
|
|
|
# ---------------------------------------------------
|
|
# FASTAPI
|
|
# ---------------------------------------------------
|
|
|
|
app = FastAPI()

# Origins allowed to call this API. Set the CORS_ALLOW_ORIGINS environment
# variable to a comma-separated list (for example
# "https://app.example.com,http://localhost:5173") to restrict access. When the
# variable is unset or blank, the previous allow-everything behaviour is kept.
# NOTE(review): the FastAPI CORS documentation advises against combining a
# wildcard origin with allow_credentials=True -- configure explicit origins
# for any deployment that is reachable from untrusted sites.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# ---------------------------------------------------
|
|
# DEBUG LOGGER
|
|
# ---------------------------------------------------
|
|
|
|
def log_debug(msg: str):
    """Append *msg* to the debug log file, prefixed with a local timestamp."""
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    entry = f"[{stamp}] {msg}\n"
    with open(DEBUG_LOG, "a", encoding="utf-8") as log_file:
        log_file.write(entry)
|
|
|
|
def log_node_duration(label: str, start_time: float, extra: str = ""):
    """Print the seconds elapsed since *start_time* for the step named *label*.

    *start_time* is compared against ``time.perf_counter()``; *extra* is
    appended verbatim after the duration.
    """
    elapsed = time.perf_counter() - start_time
    print("[{}] {:.2f}s {}".format(label, elapsed, extra))
|
|
|
|
# ---------------------------------------------------
|
|
# DUCKDB INIT
|
|
# ---------------------------------------------------
|
|
|
|
def init_duckdb(conn):
    """Prepare *conn* for querying the deliveries parquet files on S3.

    Loads the httpfs extension, applies the S3 settings and (re)creates the
    ``deliveries`` view. Any exception raised along the way is written to the
    debug log and not re-raised.
    """
    parquet_glob = 's3://nearle/parquet/deliveries/*.parquet'
    setup_statements = (
        # Load httpfs extension for S3 access
        "INSTALL httpfs; LOAD httpfs;",
        # S3 configuration matching the environment
        "SET s3_region='sgp1';",
        "SET s3_endpoint='sgp1.digitaloceanspaces.com';",
        "SET s3_url_style='path';",
    )

    try:
        log_debug("Initializing DuckDB with live S3 data")

        for statement in setup_statements:
            conn.execute(statement)

        # Expose every parquet file under the prefix as a single view.
        conn.execute(f"""
            CREATE OR REPLACE VIEW deliveries AS
            SELECT * FROM read_parquet('{parquet_glob}', union_by_name = true)
        """)

        log_debug(f"DuckDB ready with view 'deliveries' from {parquet_glob}")
    except Exception as e:
        log_debug(f"DuckDB init error: {e}")
|
|
|
|
# ---------------------------------------------------
|
|
# REQUEST MODEL
|
|
# ---------------------------------------------------
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the POST /chat endpoint."""

    message: str  # passed into the workflow as the "question"
    data: Optional[Any] = None  # optional; not read by any code visible in this file
|
|
|
|
# ---------------------------------------------------
|
|
# LANGGRAPH STATE
|
|
# ---------------------------------------------------
|
|
|
|
class WorkflowState(TypedDict):
    """State dictionary passed between the LangGraph workflow nodes."""

    question: str  # the user's message from the chat request
    query_plan: Dict[str, Any]  # JSON plan parsed from the planner LLM reply
    sql_query: str  # the plan's "sql" value, executed by the SQL node
    data: List[Dict[str, Any]]  # sanitized result rows from the SQL node
    analysis: str  # summary sentence written by the analysis node
    insights: List[str]  # insight sentences written by the insights node
    data_quality: Dict[str, Any]  # {"warnings": [...]} written by the validation node
    final_response: Dict[str, Any]  # payload returned by the /chat endpoint
    error: Optional[str]  # message of the last SQL execution error, else None
    iteration_count: int  # incremented once per SQL execution attempt
|
|
|
|
# ---------------------------------------------------
|
|
# OLLAMA CALL
|
|
# ---------------------------------------------------
|
|
|
|
def call_ollama(system_prompt: str, user_prompt: str) -> str:
    """Send a system/user prompt pair to the Ollama chat endpoint.

    Up to three attempts are made, with a one-second pause between attempts
    (but not after the last one). Request errors and blank replies are written
    to the debug log and trigger the next attempt.

    Args:
        system_prompt: Content of the "system" message.
        user_prompt: Content of the "user" message.

    Returns:
        The message content of the first non-blank reply, or "" when every
        attempt failed or came back blank.
    """
    max_attempts = 3

    # The payload does not change between attempts, so build it once.
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": False,
        "options": {
            "temperature": 0,
            "num_predict": 200,  # increased to prevent cutoff
        },
    }

    for attempt in range(1, max_attempts + 1):
        try:
            r = requests.post(OLLAMA_URL, json=payload, timeout=20)
            r.raise_for_status()

            content = r.json().get("message", {}).get("content", "")
            if content and content.strip():
                return content
            log_debug(f"Ollama empty response on attempt {attempt}")
        except Exception as e:
            # Deliberately broad: this is a best-effort call whose callers
            # rely on getting "" back instead of an exception.
            log_debug(f"Ollama call error: {e} on attempt {attempt}")

        # Pause only when another attempt follows; sleeping after the final
        # failure would just delay the empty result.
        if attempt < max_attempts:
            time.sleep(1)

    return ""
|
|
|
|
# ---------------------------------------------------
|
|
# UTILITIES
|
|
# ---------------------------------------------------
|
|
|
|
def safe_parse_json(text: str) -> Dict[str, Any]:
    """Robustly parse a JSON object from an LLM response.

    Candidates are tried in this order and the first one that decodes to a
    JSON object is returned:

    1. the whole text, stripped;
    2. the contents of the first ```json fenced block;
    3. the span from the first "{" to the last "}".

    Args:
        text: Raw text returned by the LLM.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no candidate decodes to a JSON object.
    """
    candidates = [text.strip()]

    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.S)
    if fenced:
        candidates.append(fenced.group(1).strip())

    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            # Only decoding failures move on to the next candidate; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            continue
        # Callers index the result by key, so arrays and scalars are rejected.
        if isinstance(parsed, dict):
            return parsed

    raise ValueError("Could not find valid JSON in LLM response")
|
|
|
|
# ---------------------------------------------------
|
|
# NODES
|
|
# ---------------------------------------------------
|
|
|
|
def query_planner_node(state: WorkflowState) -> WorkflowState:
    """Ask the LLM for a JSON query plan (including SQL) for the user's question.

    When the state carries an error from a previous SQL execution and at least
    one iteration has run, the error text is added to the user prompt so the
    model can correct its query. The error (or the word "None") is also
    substituted into the system prompt's ERROR section.

    Args:
        state: Current workflow state; reads "question", "error" and
            "iteration_count".

    Returns:
        A copy of the state with "query_plan" set to the parsed plan,
        "sql_query" set to the plan's SQL and "error" reset to None.

    Raises:
        Exception: If the LLM returns an empty reply, or a reply that cannot
            be parsed into a plan containing the required keys. In the latter
            case the original error is attached as the cause.
    """
    log_debug(f"NODE: Query Planner - {state['question']}")
    print(f"\n🔍 NODE: PLANNER (Iteration: {state.get('iteration_count', 0)})")
    start = time.perf_counter()

    system = """You are an AI SQL Planner.

Return ONLY valid JSON. No explanation.

---

## DATABASE

Table: deliveries

Columns:

* orderid
* tenantname
* deliveryamt (can be TEXT or NUMERIC)
* deliverydate

---

## RULES

1. Revenue = SUM(CAST(deliveryamt AS NUMERIC))

2. Never use:

* orderamount
* unsafe filters
* TRIM unless needed

3. SQL must NEVER fail:

* Always CAST for SUM
* Keep query simple

4. If previous query failed:

* Use ERROR below to FIX SQL
* NEVER repeat same query

ERROR:
{error}

---

## OUTPUT FORMAT

You MUST return a JSON object containing EXACTLY these 6 keys:

{
"intent": "kpi | trend | comparison | table",
"sql": "valid SQL query",
"chart_type": "kpi | bar | line | pie | table",
"x_axis": "column or null",
"y_axis": "column or null",
"explanation": "short explanation"
}"""

    # Read the previous error once. str.replace is used instead of str.format
    # because the prompt contains literal JSON braces. With no error the
    # section reads "None", which is what str(None) produced before.
    error_msg = state.get("error")
    system = system.replace("{error}", str(error_msg) if error_msg else "None")

    if error_msg and state.get("iteration_count", 0) > 0:
        log_debug(f"Retrying SQL generation due to error: {error_msg}")
        user_prompt = f"User Query:\n{state['question']}\n\nWarning: Your previous SQL query failed with this error in DuckDB:\n{error_msg}\n\nPlease fix the SQL query and return the correct JSON plan."
    else:
        user_prompt = f"User Query:\n{state['question']}"

    result = call_ollama(system, user_prompt)

    print("\n--- RAW LLM RESPONSE ---")
    print(result)
    print("------------------------\n")
    print("RESPONSE LENGTH:", len(result))

    if not result or not result.strip():
        raise Exception("LLM returned empty response")

    try:
        plan = safe_parse_json(result)

        # Everything below indexes the plan by key, so it must be an object.
        if not isinstance(plan, dict):
            raise ValueError("LLM response is not a JSON object")

        # Strict validation
        for key in ("intent", "sql", "chart_type", "y_axis"):
            if key not in plan:
                raise ValueError(f"Missing required key: {key}")

        if plan.get("intent") == "ranking" and "GROUP BY" not in str(plan.get("sql", "")).upper():
            raise ValueError("Ranking query SQL must contain a GROUP BY clause")

        print("\n--- PARSED PLAN ---")
        print(json.dumps(plan, indent=2))
        print("-------------------\n")

        print("\n--- FINAL SQL USED ---")
        print(plan["sql"])
        print("----------------------\n")

    except Exception as e:
        log_debug(f"Planner JSON parsing error: {str(e)} for output: {result}")
        print(f"\n❌ ERROR: Invalid LLM response. Original error: {str(e)}")
        # Chain the cause so tracebacks show why the reply was rejected.
        raise Exception("Invalid LLM response") from e

    log_node_duration("⏱️ Planner", start)
    return {**state, "query_plan": plan, "sql_query": plan.get("sql", ""), "error": None}
|
|
|
|
# Substrings that mark a column as a numeric indicator; empty cells in such
# columns are left untouched instead of being replaced by "Unknown".
_NUMERIC_KEY_MARKERS = ("orders", "revenue", "avg_order_value", "charges")


def _sanitize_value(key, value):
    """Convert one result cell into a JSON-friendly value."""
    if isinstance(value, Decimal):
        value = float(value)
    if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
        value = 0.0
    if isinstance(value, (datetime, date)):
        try:
            # Format as clean YYYY-MM-DD to save chart axis space
            value = value.strftime("%Y-%m-%d")
        except ValueError:
            # Missing timestamps (pandas NaT) pass the datetime check but
            # cannot be formatted; treat them like any other missing cell.
            value = None

    # Replace bad string primitives from DB with recognizable null equivalent
    if value is None or (isinstance(value, str) and value.strip() == ""):
        # Ignore for numeric indicators
        if not any(marker in key for marker in _NUMERIC_KEY_MARKERS):
            value = "Unknown"
    return value


def execute_sql_node(state: WorkflowState) -> WorkflowState:
    """Run the planned SQL on DuckDB and store sanitized rows in the state.

    Args:
        state: Current workflow state; reads "sql_query" and
            "iteration_count".

    Returns:
        A copy of the state with "iteration_count" incremented. On success
        "data" holds the sanitized rows and "error" is None; on any failure
        "error" holds the exception text and "data" is left unchanged.
    """
    log_debug(f"NODE: Execute SQL - {state['sql_query']}")
    print("\n🗄️ NODE: SQL EXECUTION")
    start = time.perf_counter()
    next_iteration = int(state.get("iteration_count", 0) or 0) + 1

    conn = None
    try:
        conn = duckdb.connect(DUCKDB_PATH)
        init_duckdb(conn)

        # NOTE(review): this SQL is generated by the LLM from user input and
        # is executed verbatim; consider restricting it to read-only statements.
        df = conn.execute(state["sql_query"]).df()
        sql_results = df.to_dict("records")

        # Sanitize data for JSON response
        clean_data = [
            {k: _sanitize_value(k, v) for k, v in row.items()}
            for row in sql_results
        ]

        log_node_duration("⏱️ QUERY EXECUTION", start, f"({len(clean_data)} rows)")
        return {**state, "data": clean_data, "error": None, "iteration_count": next_iteration}
    except Exception as e:
        log_debug(f"SQL Error: {e}")
        log_node_duration("⏱️ QUERY EXECUTION", start, "(ERROR)")
        return {**state, "error": str(e), "iteration_count": next_iteration}
    finally:
        # Close the connection on the failure path too; previously a failing
        # query left it open.
        if conn is not None:
            conn.close()
|
|
|
|
def validate_data_node(state: WorkflowState) -> WorkflowState:
    """Attach data-quality warnings for the current result rows.

    Warnings produced:
      * "No data returned for your query." when there are no rows;
      * "All metrics returned zero values." when the rows contain at least
        one numeric cell and every numeric cell equals zero.

    Negative numbers count as non-zero, booleans are not treated as numbers,
    and a result with no numeric cells at all produces no zero-value warning.

    Args:
        state: Current workflow state; reads "data".

    Returns:
        A copy of the state with "data_quality" set to {"warnings": [...]}.
    """
    log_debug("NODE: Data Quality Analysis")
    warnings = []

    rows = state.get("data")
    if not rows:
        warnings.append("No data returned for your query.")
        return {**state, "data_quality": {"warnings": warnings}}

    def numeric_cells():
        """Yield every int/float cell (excluding bools) across all rows."""
        for row in rows:
            for value in row.values():
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    yield value

    has_numeric = any(True for _ in numeric_cells())
    has_nonzero = any(value != 0 for value in numeric_cells())

    if has_numeric and not has_nonzero:
        warnings.append("All metrics returned zero values.")

    return {**state, "data_quality": {"warnings": warnings}}
|
|
|
|
def analyze_results_node(state: WorkflowState) -> WorkflowState:
    """Write a one-sentence summary of the result rows into "analysis".

    The sentence is built from the first row and the plan's "x_axis" /
    "y_axis" names. With no rows, or with an error in the state, a fixed
    "no data" message is used instead.
    """
    log_debug("NODE: Results Analysis")

    rows = state.get("data", [])
    if state.get("error") or not rows:
        return {**state, "analysis": "No data found to analyze."}

    total_rows = len(rows)
    query_plan = state.get("query_plan", {})
    y_col = query_plan.get("y_axis")
    x_col = query_plan.get("x_axis")
    first_row = rows[0]  # rows is non-empty past the guard above

    if not y_col:
        summary = f"Retrieved {total_rows} records for the requested analysis."
    elif x_col:
        summary = f"Analysis of {total_rows} records shows that {x_col} '{(first_row.get(x_col, 'Unknown'))}' is the top entry with {y_col} of {first_row.get(y_col, 0)}."
    else:
        summary = f"Analysis completed. The total {y_col} is {first_row.get(y_col, 0)}."

    return {**state, "analysis": summary}
|
|
|
|
def insights_node(state: WorkflowState) -> WorkflowState:
    """Build short insight sentences from the result rows.

    With both a metric ("y_axis") and a dimension ("x_axis") in the plan, the
    first and last rows are reported as top and lowest performers, and for
    multi-row results the average of the metric over the rows that actually
    have a value is added. With only a metric, the first row's value is
    reported. Data-quality warnings are appended last, prefixed with a warning
    sign.

    Args:
        state: Current workflow state; reads "data", "error", "query_plan"
            and "data_quality".

    Returns:
        A copy of the state with "insights" set (an empty list when there is
        no data or the state carries an error).
    """
    log_debug("NODE: Insights Generation")
    print("\n💡 NODE: INSIGHTS")
    start = time.perf_counter()

    data = cast(list, state.get("data", []))
    if not data or state.get("error"):
        return {**state, "insights": []}

    insights = []
    plan = state.get("query_plan", {})
    metric = plan.get("y_axis")
    dimension = plan.get("x_axis")

    if metric and dimension:
        top = data[0]
        insights.append(f"Top performer is {top.get(dimension, 'Unknown')} with {top.get(metric, 0)} {metric}.")
        if len(data) > 1:
            bottom = data[-1]
            insights.append(f"Lowest performer is {bottom.get(dimension, 'Unknown')} with {bottom.get(metric, 0)} {metric}.")
            try:
                values = [float(row[metric]) for row in data if row.get(metric) is not None]
            except (TypeError, ValueError):
                # The metric column is not numeric (e.g. text placeholders).
                insights.append("Data volume supports high-confidence trend analysis.")
            else:
                # Divide by the number of rows that contributed a value, not
                # by the total row count, so missing cells do not skew it.
                if values:
                    avg_val = sum(values) / len(values)
                    insights.append(f"The average {metric} across these records is {avg_val:.2f}.")
    elif metric:
        top = data[0]
        insights.append(f"The overall {metric} is {top.get(metric, 0)}.")
    else:
        insights = ["Data supports standard trend analysis."]

    if state.get("data_quality") and state["data_quality"].get("warnings"):
        for w in state["data_quality"]["warnings"]:
            insights.append(f"⚠ {w}")

    log_node_duration("⏱️ Insights", start)
    return {**state, "insights": insights}
|
|
|
|
def select_visualization_node(state: WorkflowState) -> WorkflowState:
    """Assemble the final API payload, including the chart configuration.

    The chart type comes from the plan, with these adjustments: an intent
    containing "table" or "list" forces a table; a single-row KPI gets a
    "summary" field (used as the x axis when none is set); a multi-row KPI
    becomes a bar chart.

    Args:
        state: Current workflow state; reads "data", "error", "query_plan",
            "analysis", "insights", "sql_query" and "data_quality".

    Returns:
        A copy of the state with "final_response" set. When there is no data
        or the state carries an error, the response holds only "analysis"
        (the error text, or a fixed "no data" message) and an empty "charts"
        list.
    """
    log_debug("NODE: Visualization Selection")
    print("\n📊 NODE: VISUALIZATION")
    start = time.perf_counter()

    data = state.get("data", [])
    error = state.get("error")
    if not data or error:
        # "error" is always present in the state (it starts as None), so a
        # dict.get default would never apply; fall back explicitly.
        return {**state, "final_response": {
            "analysis": error or "No data available to plot a chart.",
            "charts": [],
        }}

    plan = state.get("query_plan", {})
    intent = plan.get("intent", "kpi")

    chart_type = plan.get("chart_type", "kpi")
    x_axis = plan.get("x_axis")
    y_axis = plan.get("y_axis")

    # The plan comes from an LLM, so the intent is not guaranteed to be a str.
    intent_text = str(intent).lower() if intent else ""
    if "table" in intent_text or "list" in intent_text:
        chart_type = "table"

    if chart_type == "kpi":
        if len(data) == 1:
            data = [{**data[0], "summary": "Current Performance"}]
            if not x_axis:
                x_axis = "summary"
        else:
            # Several rows cannot be shown as a single KPI tile.
            chart_type = "bar"

    chart = {
        "type": chart_type,
        "title": f"{str(intent).replace('_', ' ').title()} Analysis" if intent else "Data Visualization",
        "x_axis": x_axis,
        "y_axis": y_axis,
        "data": data,
    }

    ai_decision = {
        "intent": intent,
        "chart_type": chart_type,
        "x_axis": x_axis,
        "y_axis": y_axis,
        "reason": "Directly mapped from LLM Engine structured plan.",
    }

    final_response = {
        "analysis": state["analysis"],
        "ai_decision": ai_decision,
        "insights": state.get("insights", []),
        "generated_sql": state.get("sql_query", ""),
        "query_plan": plan,
        "data_quality": state.get("data_quality", {"warnings": []}),
        "charts": [chart],
        "explanation": state["analysis"],
    }

    log_node_duration("⏱️ Visualization", start)
    return {**state, "final_response": final_response}
|
|
|
|
def should_retry(state: WorkflowState):
    """Choose the edge to follow after SQL execution.

    Returns "retry" while the state carries an error and fewer than two
    iterations have run; otherwise returns "continue".
    """
    has_error = bool(state.get("error"))
    attempts = state.get("iteration_count", 0)
    return "retry" if has_error and attempts < 2 else "continue"
|
|
|
|
# ---------------------------------------------------
|
|
# LANGGRAPH WORKFLOW
|
|
# ---------------------------------------------------
|
|
|
|
# Build the graph over the shared WorkflowState dictionary.
workflow = StateGraph(WorkflowState)

# Register each processing step under the short name used by the edges below.
workflow.add_node("planner", query_planner_node)
workflow.add_node("exec", execute_sql_node)
workflow.add_node("validate", validate_data_node)
workflow.add_node("analysis", analyze_results_node)
workflow.add_node("insights", insights_node)
workflow.add_node("viz", select_visualization_node)

# Every run starts by planning the query.
workflow.set_entry_point("planner")

workflow.add_edge("planner", "exec")

# Conditional edge for retry: should_retry() returns one of the keys below.
workflow.add_conditional_edges(
    "exec",
    should_retry,
    {
        "retry": "planner",  # Re-run planner if error
        "continue": "validate"
    }
)

# Linear tail of the pipeline: validate -> analysis -> insights -> viz -> END.
workflow.add_edge("validate", "analysis")
workflow.add_edge("analysis", "insights")
workflow.add_edge("insights", "viz")
workflow.add_edge("viz", END)

# Compiled graph invoked by the /chat endpoint.
graph = workflow.compile()
|
|
|
|
# ---------------------------------------------------
|
|
# API
|
|
# ---------------------------------------------------
|
|
|
|
@app.post("/chat")
def chat(req: ChatRequest):
    """Run the analytics workflow for one chat message.

    Returns the workflow's "final_response" on success. If anything raises,
    the error is written to the debug log and a minimal payload with the
    error text and an empty chart list is returned instead.
    """
    print(f"\n🚀 NEW REQUEST: {req.message}")
    started_at = time.perf_counter()

    try:
        # Fresh state for this request; every key of the workflow state is
        # present from the start.
        seed_state = dict(
            question=req.message,
            query_plan={},
            sql_query="",
            data=[],
            analysis="",
            insights=[],
            data_quality={"warnings": []},
            final_response={},
            error=None,
            iteration_count=0,
        )
        outcome = graph.invoke(seed_state)

        elapsed = time.perf_counter() - started_at
        print(f"\n[✅ TOTAL TIME] {elapsed:.2f}s")
        return outcome["final_response"]

    except Exception as e:
        log_debug(f"Graph overall error: {e}")
        failure = {"analysis": f"Error: {e}", "charts": []}
        return failure
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run directly.
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)