"""Natural-language analytics API.

A FastAPI app exposes ``POST /chat``. Each request runs a LangGraph workflow
that asks a local Ollama model for a JSON query plan, executes the planned SQL
in DuckDB against the ``deliveries`` view, and turns the rows into a short
analysis, a list of insights and a chart configuration.
"""

import json
import math
import os
import re
import time
import traceback
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, Any, TypedDict, List, Dict, cast

import duckdb
import requests
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langgraph.graph import StateGraph, END
from pydantic import BaseModel


# ---------------------------------------------------
# SCHEMAS (Pydantic for strict enforcement)
# ---------------------------------------------------
class ChartConfig(BaseModel):
    """Description of a single chart to be rendered by the client."""

    type: str  # e.g., "bar", "pie", "line", "area", "table", "kpi"
    title: str
    x_axis: Optional[str] = None
    y_axis: Optional[str] = None
    data: List[Dict[str, Any]]


class AnalyticsResponse(BaseModel):
    """Textual analysis plus the charts that accompany it."""

    analysis: str
    charts: List[ChartConfig]


# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
OLLAMA_URL = "http://localhost:11434/api/chat"
MODEL = "qwen2.5:3b"

# DuckDB should connect to a local persistent file
DUCKDB_PATH = "analytics.duckdb"
DEBUG_LOG = "debug.log"

# ---------------------------------------------------
# FASTAPI
# ---------------------------------------------------
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------
# DEBUG LOGGER
# ---------------------------------------------------
def log_debug(msg: str):
    """Append ``msg`` to ``DEBUG_LOG`` prefixed with a local timestamp."""
    with open(DEBUG_LOG, "a", encoding="utf-8") as f:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        f.write(f"[{timestamp}] {msg}\n")


def log_node_duration(label: str, start_time: float, extra: str = ""):
    """Print the seconds elapsed since ``start_time`` (a ``perf_counter`` value)."""
    duration = time.perf_counter() - start_time
    print(f"[{label}] {duration:.2f}s {extra}")


# ---------------------------------------------------
# DUCKDB INIT
# ---------------------------------------------------
def init_duckdb(conn):
    """Prepare ``conn`` for S3 access and (re)create the ``deliveries`` view.

    Best effort: any failure is written to the debug log and swallowed, so the
    caller only sees a problem when the subsequent query fails.
    """
    try:
        log_debug("Initializing DuckDB with live S3 data")

        # Load httpfs extension for S3 access
        conn.execute("INSTALL httpfs; LOAD httpfs;")

        # S3 configuration matching the environment
        conn.execute("SET s3_region='sgp1';")
        conn.execute("SET s3_endpoint='sgp1.digitaloceanspaces.com';")
        conn.execute("SET s3_url_style='path';")

        # Create a view from the S3 parquet files
        s3_path = 's3://nearle/parquet/deliveries/*.parquet'
        conn.execute(f"""
            CREATE OR REPLACE VIEW deliveries AS
            SELECT * FROM read_parquet('{s3_path}', union_by_name = true)
        """)
        log_debug(f"DuckDB ready with view 'deliveries' from {s3_path}")
    except Exception as e:
        log_debug(f"DuckDB init error: {e}")


# ---------------------------------------------------
# REQUEST MODEL
# ---------------------------------------------------
class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""

    message: str
    data: Optional[Any] = None


# ---------------------------------------------------
# LANGGRAPH STATE
# ---------------------------------------------------
class WorkflowState(TypedDict):
    """State dictionary passed between the workflow nodes."""

    question: str
    query_plan: Dict[str, Any]
    sql_query: str
    data: List[Dict[str, Any]]
    analysis: str
    insights: List[str]
    data_quality: Dict[str, Any]
    final_response: Dict[str, Any]
    error: Optional[str]
    iteration_count: int


# ---------------------------------------------------
# OLLAMA CALL
# ---------------------------------------------------
def call_ollama(system_prompt: str, user_prompt: str) -> str:
    """Send a chat request to Ollama and return the reply text.

    Up to three attempts are made, pausing one second between attempts.
    Request errors and empty replies are logged and retried.

    Returns:
        The model's message content, or ``""`` if every attempt failed.
    """
    max_attempts = 3
    # The payload does not change between attempts, so build it once.
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": False,
        "options": {
            "temperature": 0,
            "num_predict": 200,  # increased to prevent cutoff
        },
    }

    for attempt in range(max_attempts):
        try:
            r = requests.post(OLLAMA_URL, json=payload, timeout=20)
            r.raise_for_status()
            content = r.json().get("message", {}).get("content", "")
            if content and content.strip() != "":
                return content
            log_debug(f"Ollama empty response on attempt {attempt + 1}")
        except Exception as e:
            log_debug(f"Ollama call error: {e} on attempt {attempt + 1}")

        # Pause only when another attempt follows; sleeping after the last
        # failure would just delay the caller.
        if attempt < max_attempts - 1:
            time.sleep(1)

    return ""


# ---------------------------------------------------
# UTILITIES
# ---------------------------------------------------
def safe_parse_json(text: str) -> Dict[str, Any]:
    """Robustly parse a JSON object from an LLM response.

    Three candidates are tried in order: the whole (stripped) text, the body
    of the first ```` ```json ```` fenced block, and the span from the first
    ``{`` to the last ``}``. The first candidate that decodes to a dict wins.

    Raises:
        ValueError: if no candidate decodes to a JSON object.
    """
    candidates = [text.strip()]

    match = re.search(r"```json\s*(.*?)\s*```", text, re.S)
    if match:
        candidates.append(match.group(1).strip())

    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        # Callers use dict methods on the plan, so a bare list/number/string
        # is not an acceptable result; fall through to the next candidate.
        if isinstance(parsed, dict):
            return parsed

    raise ValueError("Could not find valid JSON in LLM response")


# ---------------------------------------------------
# NODES
# ---------------------------------------------------
# System prompt for the planner; "{error}" is substituted per request.
# NOTE(review): the line layout of this prompt was reconstructed; the wording
# is unchanged, but confirm the layout against the intended prompt.
_PLANNER_SYSTEM_PROMPT = """You are an AI SQL Planner.

Return ONLY valid JSON. No explanation.

---

## DATABASE

Table: deliveries

Columns:

* orderid
* tenantname
* deliveryamt (can be TEXT or NUMERIC)
* deliverydate

---

## RULES

1. Revenue = SUM(CAST(deliveryamt AS NUMERIC))

2. Never use:

* orderamount
* unsafe filters
* TRIM unless needed

3. SQL must NEVER fail:

* Always CAST for SUM
* Keep query simple

4. If previous query failed:

* Use ERROR below to FIX SQL
* NEVER repeat same query

ERROR: {error}

---

## OUTPUT FORMAT

You MUST return a JSON object containing EXACTLY these 6 keys:

{
"intent": "kpi | trend | comparison | table",
"sql": "valid SQL query",
"chart_type": "kpi | bar | line | pie | table",
"x_axis": "column or null",
"y_axis": "column or null",
"explanation": "short explanation"
}"""


def query_planner_node(state: WorkflowState) -> WorkflowState:
    """Ask the LLM for a JSON query plan for ``state["question"]``.

    When a previous SQL attempt failed (``error`` set and ``iteration_count``
    greater than zero) the error text is included in the user prompt so the
    model can correct the query.

    Returns:
        The state with ``query_plan`` and ``sql_query`` filled in and
        ``error`` reset to ``None``.

    Raises:
        Exception: if the model returns nothing, or its output cannot be
            parsed into a plan containing the required keys.
    """
    log_debug(f"NODE: Query Planner - {state['question']}")
    print(f"\n🔍 NODE: PLANNER (Iteration: {state.get('iteration_count', 0)})")
    start = time.perf_counter()

    error_msg = state.get("error")
    # With no previous error the prompt reads "ERROR: None".
    system = _PLANNER_SYSTEM_PROMPT.replace(
        "{error}", str(error_msg) if error_msg else "None"
    )

    if error_msg and state.get("iteration_count", 0) > 0:
        log_debug(f"Retrying SQL generation due to error: {error_msg}")
        user_prompt = f"User Query:\n{state['question']}\n\nWarning: Your previous SQL query failed with this error in DuckDB:\n{error_msg}\n\nPlease fix the SQL query and return the correct JSON plan."
    else:
        user_prompt = f"User Query:\n{state['question']}"

    result = call_ollama(system, user_prompt)

    print("\n--- RAW LLM RESPONSE ---")
    print(result)
    print("------------------------\n")
    print("RESPONSE LENGTH:", len(result))

    if not result or result.strip() == "":
        raise Exception("LLM returned empty response")

    try:
        plan = safe_parse_json(result)

        # Strict validation
        for key in ["intent", "sql", "chart_type", "y_axis"]:
            if key not in plan:
                raise ValueError(f"Missing required key: {key}")

        if plan.get("intent") == "ranking" and "GROUP BY" not in str(plan.get("sql", "")).upper():
            raise ValueError("Ranking query SQL must contain a GROUP BY clause")

        print("\n--- PARSED PLAN ---")
        print(json.dumps(plan, indent=2))
        print("-------------------\n")

        print("\n--- FINAL SQL USED ---")
        print(plan["sql"])
        print("----------------------\n")
    except Exception as exc:
        log_debug(f"Planner JSON parsing error: {str(exc)} for output: {result}")
        print(f"\n❌ ERROR: Invalid LLM response. Original error: {str(exc)}")
        # Chain the cause so the traceback keeps the real parsing failure.
        raise Exception("Invalid LLM response") from exc

    log_node_duration("⏱️ Planner", start)
    return {**state, "query_plan": plan, "sql_query": plan.get("sql", ""), "error": None}


# Column-name fragments treated as numeric indicators: blank values in such
# columns are left untouched instead of being replaced with "Unknown".
_NUMERIC_KEY_HINTS = ("orders", "revenue", "avg_order_value", "charges")


def _sanitize_cell(key: str, value: Any) -> Any:
    """Convert one result cell into a JSON-friendly value."""
    if isinstance(value, Decimal):
        value = float(value)
    if isinstance(value, float) and not math.isfinite(value):
        value = 0.0
    if isinstance(value, (datetime, date)):
        # Format as clean YYYY-MM-DD to save chart axis space
        value = value.strftime("%Y-%m-%d")

    # Replace missing/blank values with a recognizable placeholder, except in
    # columns that look like numeric indicators.
    is_blank = value is None or (isinstance(value, str) and value.strip() == "")
    if is_blank and not any(hint in key for hint in _NUMERIC_KEY_HINTS):
        value = "Unknown"
    return value


def execute_sql_node(state: WorkflowState) -> WorkflowState:
    """Run ``state["sql_query"]`` in DuckDB and store the sanitized rows.

    ``iteration_count`` is incremented whether or not the query succeeds. On
    failure the exception text is stored in ``error`` instead of being raised,
    so the workflow can route back to the planner.
    """
    log_debug(f"NODE: Execute SQL - {state['sql_query']}")
    print("\n🗄️ NODE: SQL EXECUTION")
    start = time.perf_counter()
    next_iteration = int(state.get("iteration_count", 0) or 0) + 1

    try:
        conn = duckdb.connect(DUCKDB_PATH)
        try:
            init_duckdb(conn)
            # NOTE(review): this SQL is produced by the LLM from user input and
            # is executed as-is; nothing here restricts it to read-only
            # statements.
            df = conn.execute(state["sql_query"]).df()
        finally:
            # Always release the connection, including when the query raises.
            conn.close()

        # Sanitize data for JSON response
        clean_data = [
            {k: _sanitize_cell(k, v) for k, v in row.items()}
            for row in df.to_dict("records")
        ]

        log_node_duration("⏱️ QUERY EXECUTION", start, f"({len(clean_data)} rows)")
        return {**state, "data": clean_data, "error": None, "iteration_count": next_iteration}
    except Exception as e:
        log_debug(f"SQL Error: {e}")
        log_node_duration("⏱️ QUERY EXECUTION", start, "(ERROR)")
        return {**state, "error": str(e), "iteration_count": next_iteration}


def validate_data_node(state: WorkflowState) -> WorkflowState:
    """Attach data-quality warnings for empty or all-zero result sets.

    The all-zero warning is raised only when the rows contain at least one
    numeric value and every numeric value equals zero. Negative numbers count
    as non-zero, and booleans are not treated as metrics.
    """
    log_debug("NODE: Data Quality Analysis")
    warnings: List[str] = []

    rows = state.get("data")
    if not rows:
        warnings.append("No data returned for your query.")
        return {**state, "data_quality": {"warnings": warnings}}

    numeric_values = [
        v
        for row in rows
        for v in row.values()
        if isinstance(v, (int, float)) and not isinstance(v, bool)
    ]
    if numeric_values and all(v == 0 for v in numeric_values):
        warnings.append("All metrics returned zero values.")

    return {**state, "data_quality": {"warnings": warnings}}


def analyze_results_node(state: WorkflowState) -> WorkflowState:
    """Write a one-sentence summary of the result set into ``analysis``.

    The first row is described as the top entry; no sorting is done here.
    """
    log_debug("NODE: Results Analysis")
    data = state.get("data", [])
    if not data or state.get("error"):
        return {**state, "analysis": "No data found to analyze."}

    row_count = len(data)
    plan = state.get("query_plan", {})
    metric = plan.get("y_axis")
    dimension = plan.get("x_axis")
    top = data[0]

    if dimension and metric:
        analysis = f"Analysis of {row_count} records shows that {dimension} '{(top.get(dimension, 'Unknown'))}' is the top entry with {metric} of {top.get(metric, 0)}."
    elif metric:
        analysis = f"Analysis completed. The total {metric} is {top.get(metric, 0)}."
    else:
        analysis = f"Retrieved {row_count} records for the requested analysis."

    return {**state, "analysis": analysis}


def insights_node(state: WorkflowState) -> WorkflowState:
    """Derive short bullet-point insights from the result rows.

    The first and last rows are reported as top and lowest performer. The
    average is taken over the rows that actually carry a value for the metric.
    Data-quality warnings are appended to the list with a "⚠" prefix.
    """
    log_debug("NODE: Insights Generation")
    print("\n💡 NODE: INSIGHTS")
    start = time.perf_counter()

    data = cast(list, state.get("data", []))
    if not data or state.get("error"):
        return {**state, "insights": []}

    insights: List[str] = []
    plan = state.get("query_plan", {})
    metric = plan.get("y_axis")
    dimension = plan.get("x_axis")

    if metric and dimension:
        top = data[0]
        insights.append(f"Top performer is {top.get(dimension, 'Unknown')} with {top.get(metric, 0)} {metric}.")

        if len(data) > 1:
            bottom = data[-1]
            insights.append(f"Lowest performer is {bottom.get(dimension, 'Unknown')} with {bottom.get(metric, 0)} {metric}.")

            try:
                values = [
                    float(row[metric])
                    for row in data
                    if row.get(metric) is not None
                ]
            except (TypeError, ValueError):
                # Non-numeric metric values (e.g. the "Unknown" placeholder).
                insights.append("Data volume supports high-confidence trend analysis.")
            else:
                # Divide by the number of values summed, not by the row count,
                # so rows without a value do not drag the average down.
                if values:
                    avg_val = sum(values) / len(values)
                    insights.append(f"The average {metric} across these records is {avg_val:.2f}.")
    elif metric:
        top = data[0]
        insights.append(f"The overall {metric} is {top.get(metric, 0)}.")
    else:
        insights = ["Data supports standard trend analysis."]

    quality = state.get("data_quality")
    if quality and quality.get("warnings"):
        for w in quality["warnings"]:
            insights.append(f"⚠ {w}")

    log_node_duration("⏱️ Insights", start)
    return {**state, "insights": insights}


def select_visualization_node(state: WorkflowState) -> WorkflowState:
    """Assemble ``final_response``: chart config, analysis and metadata.

    With no rows or a pending error, the response carries no charts and the
    analysis is the error text, or a fixed message when there is no error.
    """
    log_debug("NODE: Visualization Selection")
    print("\n📊 NODE: VISUALIZATION")
    start = time.perf_counter()

    data = state.get("data", [])
    if not data or state.get("error"):
        return {**state, "final_response": {
            # The "error" key is present even when it is None, so a .get()
            # default would never be used; fall back explicitly instead.
            "analysis": state.get("error") or "No data available to plot a chart.",
            "charts": []
        }}

    plan = state.get("query_plan", {})
    intent = plan.get("intent", "kpi")
    chart_type = plan.get("chart_type", "kpi")
    x_axis = plan.get("x_axis")
    y_axis = plan.get("y_axis")

    # The intent comes from the LLM; stringify it before text checks so a
    # non-string value cannot raise here.
    intent_text = str(intent).lower() if intent else ""
    if "table" in intent_text or "list" in intent_text:
        chart_type = "table"

    if chart_type == "kpi":
        if len(data) == 1:
            data = [{**data[0], "summary": "Current Performance"}]
            if not x_axis:
                x_axis = "summary"
        else:
            # A KPI needs exactly one row; several rows are shown as bars.
            chart_type = "bar"

    chart = {
        "type": chart_type,
        "title": f"{str(intent).replace('_', ' ').title()} Analysis" if intent else "Data Visualization",
        "x_axis": x_axis,
        "y_axis": y_axis,
        "data": data
    }

    ai_decision = {
        "intent": intent,
        "chart_type": chart_type,
        "x_axis": x_axis,
        "y_axis": y_axis,
        "reason": "Directly mapped from LLM Engine structured plan.",
    }

    final_response = {
        "analysis": state["analysis"],
        "ai_decision": ai_decision,
        "insights": state.get("insights", []),
        "generated_sql": state.get("sql_query", ""),
        "query_plan": plan,
        "data_quality": state.get("data_quality", {"warnings": []}),
        "charts": [chart],
        "explanation": state["analysis"]
    }

    log_node_duration("⏱️ Visualization", start)
    return {**state, "final_response": final_response}


def should_retry(state: WorkflowState):
    """Return ``"retry"`` while an error is set and fewer than two executions
    have run; otherwise return ``"continue"``."""
    if state.get("error") and state.get("iteration_count", 0) < 2:
        return "retry"
    return "continue"


# ---------------------------------------------------
# LANGGRAPH WORKFLOW
# ---------------------------------------------------
workflow = StateGraph(WorkflowState)

workflow.add_node("planner", query_planner_node)
workflow.add_node("exec", execute_sql_node)
workflow.add_node("validate", validate_data_node)
workflow.add_node("analysis", analyze_results_node)
workflow.add_node("insights", insights_node)
workflow.add_node("viz", select_visualization_node)

workflow.set_entry_point("planner")
workflow.add_edge("planner", "exec")

# Conditional edge for retry
workflow.add_conditional_edges(
    "exec",
    should_retry,
    {
        "retry": "planner",  # Re-run planner if error
        "continue": "validate"
    }
)

workflow.add_edge("validate", "analysis")
workflow.add_edge("analysis", "insights")
workflow.add_edge("insights", "viz")
workflow.add_edge("viz", END)

graph = workflow.compile()


# ---------------------------------------------------
# API
# ---------------------------------------------------
@app.post("/chat")
def chat(req: ChatRequest):
    """Run the workflow for ``req.message`` and return its final response.

    Any exception raised by the workflow is logged and reported to the client
    as an ``analysis`` string with an empty chart list.
    """
    print(f"\n🚀 NEW REQUEST: {req.message}")
    total_start = time.perf_counter()

    try:
        initial_state = {
            "question": req.message,
            "query_plan": {},
            "sql_query": "",
            "data": [],
            "analysis": "",
            "insights": [],
            "data_quality": {"warnings": []},
            "final_response": {},
            "error": None,
            "iteration_count": 0
        }

        result = graph.invoke(initial_state)

        total_duration = time.perf_counter() - total_start
        print(f"\n[✅ TOTAL TIME] {total_duration:.2f}s")

        return result["final_response"]
    except Exception as e:
        log_debug(f"Graph overall error: {e}")
        return {
            "analysis": f"Error: {str(e)}",
            "charts": []
        }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)