initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

549
ai_service.py Normal file
View File

@@ -0,0 +1,549 @@
from fastapi import FastAPI
from pydantic import BaseModel
import json
import requests
import re
import traceback
import duckdb
import math
import os
from typing import Optional, Any, TypedDict, List, Dict, cast
from datetime import datetime, date
from decimal import Decimal
import time
# ---------------------------------------------------
# SCHEMAS (Pydantic for strict enforcement)
# ---------------------------------------------------
class ChartConfig(BaseModel):
type: str # e.g., "bar", "pie", "line", "area", "table", "kpi"
title: str
x_axis: Optional[str] = None
y_axis: Optional[str] = None
data: List[Dict[str, Any]]
class AnalyticsResponse(BaseModel):
analysis: str
charts: List[ChartConfig]
from fastapi.middleware.cors import CORSMiddleware
from langgraph.graph import StateGraph, END
# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
OLLAMA_URL = "http://localhost:11434/api/chat"
MODEL = "qwen2.5:3b"
# DuckDB should connect to a local persistent file
DUCKDB_PATH = "analytics.duckdb"
DEBUG_LOG = "debug.log"
# ---------------------------------------------------
# FASTAPI
# ---------------------------------------------------
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------
# DEBUG LOGGER
# ---------------------------------------------------
def log_debug(msg: str):
with open(DEBUG_LOG, "a", encoding="utf-8") as f:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
f.write(f"[{timestamp}] {msg}\n")
def log_node_duration(label: str, start_time: float, extra: str = ""):
duration = time.perf_counter() - start_time
print(f"[{label}] {duration:.2f}s {extra}")
# ---------------------------------------------------
# DUCKDB INIT
# ---------------------------------------------------
def init_duckdb(conn):
try:
log_debug("Initializing DuckDB with live S3 data")
# Load httpfs extension for S3 access
conn.execute("INSTALL httpfs; LOAD httpfs;")
# S3 configuration matching the environment
conn.execute("SET s3_region='sgp1';")
conn.execute("SET s3_endpoint='sgp1.digitaloceanspaces.com';")
conn.execute("SET s3_url_style='path';")
# Create a view from the S3 parquet files
s3_path = 's3://nearle/parquet/deliveries/*.parquet'
conn.execute(f"""
CREATE OR REPLACE VIEW deliveries AS
SELECT * FROM read_parquet('{s3_path}', union_by_name = true)
""")
log_debug(f"DuckDB ready with view 'deliveries' from {s3_path}")
except Exception as e:
log_debug(f"DuckDB init error: {e}")
pass
# ---------------------------------------------------
# REQUEST MODEL
# ---------------------------------------------------
class ChatRequest(BaseModel):
message: str
data: Optional[Any] = None
# ---------------------------------------------------
# LANGGRAPH STATE
# ---------------------------------------------------
class WorkflowState(TypedDict):
question: str
query_plan: Dict[str, Any]
sql_query: str
data: List[Dict[str, Any]]
analysis: str
insights: List[str]
data_quality: Dict[str, Any]
final_response: Dict[str, Any]
error: Optional[str]
iteration_count: int
# ---------------------------------------------------
# OLLAMA CALL
# ---------------------------------------------------
def call_ollama(system_prompt: str, user_prompt: str) -> str:
for attempt in range(3):
try:
payload = {
"model": MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"stream": False,
"options": {
"temperature": 0,
"num_predict": 200, # 🔥 increased to prevent cutoff
}
}
r = requests.post(OLLAMA_URL, json=payload, timeout=20)
r.raise_for_status()
content = r.json().get("message", {}).get("content", "")
if content and content.strip() != "":
return content
log_debug(f"Ollama empty response on attempt {attempt + 1}")
except Exception as e:
log_debug(f"Ollama call error: {e} on attempt {attempt + 1}")
time.sleep(1)
return ""
# ---------------------------------------------------
# UTILITIES
# ---------------------------------------------------
def safe_parse_json(text: str) -> Dict[str, Any]:
"""Robustly parse JSON from LLM response."""
try:
return json.loads(text.strip())
except json.JSONDecodeError:
match = re.search(r"```json\s*(.*?)\s*```", text, re.S)
if match:
try:
return json.loads(match.group(1).strip())
except: pass
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1:
try:
return json.loads(text[start:end+1])
except: pass
raise ValueError("Could not find valid JSON in LLM response")
# ---------------------------------------------------
# NODES
# ---------------------------------------------------
def query_planner_node(state: WorkflowState) -> WorkflowState:
log_debug(f"NODE: Query Planner - {state['question']}")
print(f"\n🔍 NODE: PLANNER (Iteration: {state.get('iteration_count', 0)})")
start = time.perf_counter()
system = """You are an AI SQL Planner.
Return ONLY valid JSON. No explanation.
---
## DATABASE
Table: deliveries
Columns:
* orderid
* tenantname
* deliveryamt (can be TEXT or NUMERIC)
* deliverydate
---
## RULES
1. Revenue = SUM(CAST(deliveryamt AS NUMERIC))
2. Never use:
* orderamount
* unsafe filters
* TRIM unless needed
3. SQL must NEVER fail:
* Always CAST for SUM
* Keep query simple
4. If previous query failed:
* Use ERROR below to FIX SQL
* NEVER repeat same query
ERROR:
{error}
---
## OUTPUT FORMAT
You MUST return a JSON object containing EXACTLY these 6 keys:
{
"intent": "kpi | trend | comparison | table",
"sql": "valid SQL query",
"chart_type": "kpi | bar | line | pie | table",
"x_axis": "column or null",
"y_axis": "column or null",
"explanation": "short explanation"
}"""
error_msg = state.get("error", "")
system = system.replace("{error}", str(error_msg))
error_msg = state.get("error")
if error_msg and state.get("iteration_count", 0) > 0:
log_debug(f"Retrying SQL generation due to error: {error_msg}")
user_prompt = f"User Query:\n{state['question']}\n\nWarning: Your previous SQL query failed with this error in DuckDB:\n{error_msg}\n\nPlease fix the SQL query and return the correct JSON plan."
else:
user_prompt = f"User Query:\n{state['question']}"
result = call_ollama(system, user_prompt)
print("\n--- RAW LLM RESPONSE ---")
print(result)
print("------------------------\n")
print("RESPONSE LENGTH:", len(result))
if not result or result.strip() == "":
raise Exception("LLM returned empty response")
try:
plan = safe_parse_json(result)
# Strict validation
for key in ["intent", "sql", "chart_type", "y_axis"]:
if key not in plan:
raise ValueError(f"Missing required key: {key}")
if plan.get("intent") == "ranking" and "GROUP BY" not in str(plan.get("sql", "")).upper():
raise ValueError("Ranking query SQL must contain a GROUP BY clause")
print("\n--- PARSED PLAN ---")
print(json.dumps(plan, indent=2))
print("-------------------\n")
print("\n--- FINAL SQL USED ---")
print(plan["sql"])
print("----------------------\n")
except Exception as e:
log_debug(f"Planner JSON parsing error: {str(e)} for output: {result}")
print(f"\n❌ ERROR: Invalid LLM response. Original error: {str(e)}")
raise Exception("Invalid LLM response")
log_node_duration("⏱️ Planner", start)
return {**state, "query_plan": plan, "sql_query": plan.get("sql", ""), "error": None}
def execute_sql_node(state: WorkflowState) -> WorkflowState:
log_debug(f"NODE: Execute SQL - {state['sql_query']}")
print(f"\n🗄️ NODE: SQL EXECUTION")
start = time.perf_counter()
try:
conn = duckdb.connect(DUCKDB_PATH)
init_duckdb(conn)
df = conn.execute(state["sql_query"]).df()
sql_results = df.to_dict("records")
conn.close()
# Sanitize data for JSON response
clean_data = []
for r in sql_results:
row_clean = {}
for k, v in r.items():
if isinstance(v, Decimal): v = float(v)
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)): v = 0.0
if isinstance(v, (datetime, date)):
# Format as clean YYYY-MM-DD to save chart axis space
v = v.strftime("%Y-%m-%d")
# Replace bad string primitives from DB with recognizable null equivalent
if v is None or (isinstance(v, str) and v.strip() == ""):
# Ignore for numeric indicators
if not ("orders" in k or "revenue" in k or "avg_order_value" in k or "charges" in k):
v = "Unknown"
row_clean[k] = v
clean_data.append(row_clean)
log_node_duration("⏱️ QUERY EXECUTION", start, f"({len(clean_data)} rows)")
return {**state, "data": clean_data, "error": None, "iteration_count": int(state.get("iteration_count", 0) or 0) + 1}
except Exception as e:
log_debug(f"SQL Error: {e}")
log_node_duration("⏱️ QUERY EXECUTION", start, "(ERROR)")
return {**state, "error": str(e), "iteration_count": int(state.get("iteration_count", 0) or 0) + 1}
def validate_data_node(state: WorkflowState) -> WorkflowState:
log_debug("NODE: Data Quality Analysis")
warnings = []
if not state.get("data"):
warnings.append("No data returned for your query.")
return {**state, "data_quality": {"warnings": warnings}}
all_zeros = True
for row in state["data"]:
for k, v in row.items():
if isinstance(v, (int, float)) and v > 0:
all_zeros = False
break
if not all_zeros: break
if all_zeros:
warnings.append("All metrics returned zero values.")
return {**state, "data_quality": {"warnings": warnings}}
def analyze_results_node(state: WorkflowState) -> WorkflowState:
log_debug("NODE: Results Analysis")
data = state.get("data", [])
if not data or state.get("error"):
return {**state, "analysis": "No data found to analyze."}
row_count = len(data)
plan = state.get("query_plan", {})
metric = plan.get("y_axis")
dimension = plan.get("x_axis")
if dimension and metric and row_count > 0:
top = data[0]
analysis = f"Analysis of {row_count} records shows that {dimension} '{(top.get(dimension, 'Unknown'))}' is the top entry with {metric} of {top.get(metric, 0)}."
elif metric and row_count > 0:
top = data[0]
analysis = f"Analysis completed. The total {metric} is {top.get(metric, 0)}."
else:
analysis = f"Retrieved {row_count} records for the requested analysis."
return {**state, "analysis": analysis}
def insights_node(state: WorkflowState) -> WorkflowState:
log_debug("NODE: Insights Generation")
print(f"\n💡 NODE: INSIGHTS")
start = time.perf_counter()
data = cast(list, state.get("data", []))
if not data or state.get("error"):
return {**state, "insights": []}
insights = []
plan = state.get("query_plan", {})
metric = plan.get("y_axis")
dimension = plan.get("x_axis")
if metric and dimension and len(data) > 0:
top = data[0]
insights.append(f"Top performer is {top.get(dimension, 'Unknown')} with {top.get(metric, 0)} {metric}.")
if len(data) > 1:
bottom = data[-1]
insights.append(f"Lowest performer is {bottom.get(dimension, 'Unknown')} with {bottom.get(metric, 0)} {metric}.")
try:
total = sum(float(row.get(metric, 0)) for row in data if row.get(metric) is not None)
avg_val = total / len(data)
insights.append(f"The average {metric} across these records is {avg_val:.2f}.")
except:
insights.append("Data volume supports high-confidence trend analysis.")
elif metric and len(data) > 0:
top = data[0]
insights.append(f"The overall {metric} is {top.get(metric, 0)}.")
else:
insights = ["Data supports standard trend analysis."]
if state.get("data_quality") and state["data_quality"].get("warnings"):
for w in state["data_quality"]["warnings"]:
insights.append(f"{w}")
log_node_duration("⏱️ Insights", start)
return {**state, "insights": insights}
def select_visualization_node(state: WorkflowState) -> WorkflowState:
log_debug("NODE: Visualization Selection")
print(f"\n📊 NODE: VISUALIZATION")
start = time.perf_counter()
data = state.get("data", [])
if not data or state.get("error"):
return {**state, "final_response": {
"analysis": state.get("error", "No data available to plot a chart."),
"charts": []
}}
plan = state.get("query_plan", {})
intent = plan.get("intent", "kpi")
chart_type = plan.get("chart_type", "kpi")
x_axis = plan.get("x_axis")
y_axis = plan.get("y_axis")
if intent and "table" in intent.lower():
chart_type = "table"
if intent and "list" in intent.lower():
chart_type = "table"
if chart_type == "kpi" and len(data) == 1:
data = [{**data[0], "summary": "Current Performance"}]
if not x_axis:
x_axis = "summary"
if chart_type == "kpi" and len(data) > 1:
chart_type = "bar"
chart = {
"type": chart_type,
"title": f"{str(intent).replace('_', ' ').title()} Analysis" if intent else "Data Visualization",
"x_axis": x_axis,
"y_axis": y_axis,
"data": data
}
ai_decision = {
"intent": intent,
"chart_type": chart_type,
"x_axis": x_axis,
"y_axis": y_axis,
"reason": "Directly mapped from LLM Engine structured plan.",
}
final_response = {
"analysis": state["analysis"],
"ai_decision": ai_decision,
"insights": state.get("insights", []),
"generated_sql": state.get("sql_query", ""),
"query_plan": plan,
"data_quality": state.get("data_quality", {"warnings": []}),
"charts": [chart],
"explanation": state["analysis"]
}
log_node_duration("⏱️ Visualization", start)
return {**state, "final_response": final_response}
def should_retry(state: WorkflowState):
if state.get("error") and state.get("iteration_count", 0) < 2:
return "retry"
return "continue"
# ---------------------------------------------------
# LANGGRAPH WORKFLOW
# ---------------------------------------------------
workflow = StateGraph(WorkflowState)
workflow.add_node("planner", query_planner_node)
workflow.add_node("exec", execute_sql_node)
workflow.add_node("validate", validate_data_node)
workflow.add_node("analysis", analyze_results_node)
workflow.add_node("insights", insights_node)
workflow.add_node("viz", select_visualization_node)
workflow.set_entry_point("planner")
workflow.add_edge("planner", "exec")
# Conditional edge for retry
workflow.add_conditional_edges(
"exec",
should_retry,
{
"retry": "planner", # Re-run planner if error
"continue": "validate"
}
)
workflow.add_edge("validate", "analysis")
workflow.add_edge("analysis", "insights")
workflow.add_edge("insights", "viz")
workflow.add_edge("viz", END)
graph = workflow.compile()
# ---------------------------------------------------
# API
# ---------------------------------------------------
@app.post("/chat")
def chat(req: ChatRequest):
print(f"\n🚀 NEW REQUEST: {req.message}")
total_start = time.perf_counter()
try:
initial_state = {
"question": req.message,
"query_plan": {},
"sql_query": "",
"data": [],
"analysis": "",
"insights": [],
"data_quality": {"warnings": []},
"final_response": {},
"error": None,
"iteration_count": 0
}
result = graph.invoke(initial_state)
total_duration = time.perf_counter() - total_start
print(f"\n[✅ TOTAL TIME] {total_duration:.2f}s")
return result["final_response"]
except Exception as e:
log_debug(f"Graph overall error: {e}")
return {
"analysis": f"Error: {str(e)}",
"charts": []
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)