initial commit

2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

@@ -0,0 +1 @@
"""Tools for interacting with a PowerBI dataset."""

@@ -0,0 +1,70 @@
# flake8: noqa
QUESTION_TO_QUERY_BASE = """
Answer the question below with a DAX query that can be sent to Power BI. DAX queries have a simple syntax comprised of just one required keyword, EVALUATE, and several optional keywords: ORDER BY, START AT, DEFINE, MEASURE, VAR, TABLE, and COLUMN. Each keyword defines a statement used for the duration of the query. Wherever < and > appear in the text below, they mark placeholders that need to be replaced by tables, columns, or other values. If the question is not something you can answer with a DAX query, reply with "I cannot answer this" and the question will be escalated to a human.
Some DAX functions return a table instead of a scalar, and must be wrapped in a function that evaluates the table and returns a scalar; unless the table is a single column, single row table, then it is treated as a scalar value. Most DAX functions require one or more arguments, which can include tables, columns, expressions, and values. However, some functions, such as PI, do not require any arguments, but always require parentheses to indicate the null argument. For example, you must always type PI(), not PI. You can also nest functions within other functions.
Some commonly used functions are:
EVALUATE <table> - At the most basic level, a DAX query is an EVALUATE statement containing a table expression. At least one EVALUATE statement is required, however, a query can contain any number of EVALUATE statements.
EVALUATE <table> ORDER BY <expression> ASC or DESC - The optional ORDER BY keyword defines one or more expressions used to sort query results. Any expression that can be evaluated for each row of the result is valid.
EVALUATE <table> ORDER BY <expression> ASC or DESC START AT <value> or <parameter> - The optional START AT keyword is used inside an ORDER BY clause. It defines the value at which the query results begin.
DEFINE MEASURE | VAR; EVALUATE <table> - The optional DEFINE keyword introduces one or more calculated entity definitions that exist only for the duration of the query. Definitions precede the EVALUATE statement and are valid for all EVALUATE statements in the query. Definitions can be variables, measures, tables, and columns. Definitions can reference other definitions that appear before or after the current definition. At least one definition is required if the DEFINE keyword is included in a query.
MEASURE <table name>[<measure name>] = <scalar expression> - Introduces a measure definition in a DEFINE statement of a DAX query.
VAR <name> = <expression> - Stores the result of an expression as a named variable, which can then be passed as an argument to other measure expressions. Once resultant values have been calculated for a variable expression, those values do not change, even if the variable is referenced in another expression.
FILTER(<table>,<filter>) - Returns a table that represents a subset of another table or expression, where <filter> is a Boolean expression that is to be evaluated for each row of the table. For example, [Amount] > 0 or [Region] = "France"
ROW(<name>, <expression>) - Returns a table with a single row containing values that result from the expressions given to each column.
TOPN(<n>, <table>, <OrderBy_Expression>, <Order>) - Returns a table with the top n rows from the specified table, sorted by the specified expression, in the order specified by 0 for descending, 1 for ascending, the default is 0. Multiple OrderBy_Expressions and Order pairs can be given, separated by a comma.
DISTINCT(<column>) - Returns a one-column table that contains the distinct values from the specified column. In other words, duplicate values are removed and only unique values are returned. This function cannot be used to return values into a cell or column on a worksheet; rather, you nest the DISTINCT function within a formula to get a list of distinct values that can be passed to another function and then counted, summed, or used for other operations.
DISTINCT(<table>) - Returns a table by removing duplicate rows from another table or expression.
Aggregation functions with an A in the name handle Booleans and empty strings in appropriate ways, while the same function without the A uses only the numeric values in a column. Functions with an X in the name take an expression as an argument; it is evaluated for each row of the table and the results are used in the regular function calculation. These are the functions:
COUNT(<column>), COUNTA(<column>), COUNTX(<table>,<expression>), COUNTAX(<table>,<expression>), COUNTROWS([<table>]), COUNTBLANK(<column>), DISTINCTCOUNT(<column>), DISTINCTCOUNTNOBLANK(<column>) - these are all variations of count functions.
AVERAGE(<column>), AVERAGEA(<column>), AVERAGEX(<table>,<expression>) - these are all variations of average functions.
MAX(<column>), MAXA(<column>), MAXX(<table>,<expression>) - these are all variations of max functions.
MIN(<column>), MINA(<column>), MINX(<table>,<expression>) - these are all variations of min functions.
PRODUCT(<column>), PRODUCTX(<table>,<expression>) - these are all variations of product functions.
SUM(<column>), SUMX(<table>,<expression>) - these are all variations of sum functions.
Date and time functions:
DATE(year, month, day) - Returns a date value that represents the specified year, month, and day.
DATEDIFF(date1, date2, <interval>) - Returns the difference between two date values, in the specified interval, that can be SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR.
DATEVALUE(<date_text>) - Returns a date value that represents the specified date.
YEAR(<date>), QUARTER(<date>), MONTH(<date>), DAY(<date>), HOUR(<date>), MINUTE(<date>), SECOND(<date>) - Returns the part of the date for the specified date.
Finally, make sure to escape double quotes with a single backslash, and make sure that only table names have single quotes around them, while names of measures or the values of columns that you want to compare against are in escaped double quotes. Newlines are not necessary and can be skipped. The queries are serialized as JSON and so must be compliant with JSON syntax. Sometimes you will get a question, a DAX query, and an error; in that case you need to rewrite the DAX query to get the correct answer.
The following tables exist: {tables}
and the schemas for some are given here:
{schemas}
Examples:
{examples}
"""
USER_INPUT = """
Question: {tool_input}
DAX:
"""
SINGLE_QUESTION_TO_QUERY = f"{QUESTION_TO_QUERY_BASE}{USER_INPUT}"
DEFAULT_FEWSHOT_EXAMPLES = """
Question: How many rows are in the table <table>?
DAX: EVALUATE ROW(\"Number of rows\", COUNTROWS(<table>))
----
Question: How many rows are in the table <table> where <column> is not empty?
DAX: EVALUATE ROW(\"Number of rows\", COUNTROWS(FILTER(<table>, <table>[<column>] <> \"\")))
----
Question: What was the average of <column> in <table>?
DAX: EVALUATE ROW(\"Average\", AVERAGE(<table>[<column>]))
----
"""
RETRY_RESPONSE = (
    "{tool_input} DAX: {query} Error: {error}. Please supply a new DAX query."
)
BAD_REQUEST_RESPONSE = "Error on this question, the error was {error}, you can try to rephrase the question."
SCHEMA_ERROR_RESPONSE = "Bad request, are you sure the table name is correct?"
UNAUTHORIZED_RESPONSE = "Unauthorized. Try changing your authentication, do not retry."
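
The templates in this file are ordinary `str.format` strings. The sketch below shows how they might be filled in. It assumes a shortened stand-in (`BASE`) for QUESTION_TO_QUERY_BASE, and the table names, schema text, and error message are hypothetical values that are not part of this commit.

```python
# Shortened stand-in for QUESTION_TO_QUERY_BASE; the real prompt is much longer.
BASE = """
The following tables exist: {tables}
and the schemas for some are given here:
{schemas}
Examples:
{examples}
"""
USER_INPUT = """
Question: {tool_input}
DAX:
"""
# Values interpolated by an f-string are not re-parsed, so the {placeholders}
# inside BASE and USER_INPUT survive and can be filled in later with str.format.
SINGLE = f"{BASE}{USER_INPUT}"

RETRY_RESPONSE = (
    "{tool_input} DAX: {query} Error: {error}. Please supply a new DAX query."
)

question = "How many rows are in the table Sales?"

prompt = SINGLE.format(
    tables="Sales, Customers",         # hypothetical table names
    schemas="Sales (Amount, Region)",  # hypothetical schema text
    examples="",
    tool_input=question,
)

# After a failed query, the question, the query, and the error are folded
# into one string that becomes the next tool input.
retry_input = RETRY_RESPONSE.format(
    tool_input=question,
    query="EVALUATE ROW(\"Number of rows\", COUNTROWS(Sale))",
    error="Cannot find table 'Sale'",  # hypothetical error text
)
print(prompt)
print(retry_input)
```

Because the retry text is passed back in as `tool_input`, it lands after `Question:` in the next prompt, which is why the base prompt tells the model it may receive a question, a DAX query, and an error together.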

@@ -0,0 +1,276 @@
"""Tools for interacting with a Power BI dataset."""

import logging
from time import perf_counter
from typing import Any, Dict, Optional, Tuple

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import ConfigDict, Field, model_validator

from langchain_community.chat_models.openai import _import_tiktoken
from langchain_community.tools.powerbi.prompt import (
    BAD_REQUEST_RESPONSE,
    DEFAULT_FEWSHOT_EXAMPLES,
    RETRY_RESPONSE,
)
from langchain_community.utilities.powerbi import PowerBIDataset, json_to_md

logger = logging.getLogger(__name__)


class QueryPowerBITool(BaseTool):
    """Tool for querying a Power BI Dataset."""

    name: str = "query_powerbi"
    description: str = """
    Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification.
    Example Input: "How many rows are in table1?"
    """  # noqa: E501
    llm_chain: Any = None
    powerbi: PowerBIDataset = Field(exclude=True)
    examples: Optional[str] = DEFAULT_FEWSHOT_EXAMPLES
    session_cache: Dict[str, Any] = Field(default_factory=dict, exclude=True)
    max_iterations: int = 5
    output_token_limit: int = 4000
    # Passed to tiktoken's encoding_for_model, so this should be a model name
    # (for example "gpt-4"); when left as None the token limit is not enforced.
    tiktoken_model_name: Optional[str] = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="before")
    @classmethod
    def validate_llm_chain_input_variables(  # pylint: disable=E0213
        cls, values: dict
    ) -> dict:
        """Make sure the LLM chain has the correct input variables."""
        llm_chain = values["llm_chain"]
        allowed = ["tool_input", "tables", "schemas", "examples"]
        for var in llm_chain.prompt.input_variables:
            if var not in allowed:
                # Exceptions do not apply %-style arguments the way logging
                # calls do, so the message is formatted explicitly.
                raise ValueError(
                    "LLM chain for QueryPowerBITool must have input variables "
                    f"{allowed}, found {llm_chain.prompt.input_variables}"
                )
        return values

    def _check_cache(self, tool_input: str) -> Optional[str]:
        """Return the cached response for this input, or None if not cached."""
        if tool_input not in self.session_cache:
            return None
        return self.session_cache[tool_input]

    def _run(
        self,
        tool_input: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> str:
        """Execute the query, return the results or an error message."""
        if cache := self._check_cache(tool_input):
            logger.debug("Found cached result for %s: %s", tool_input, cache)
            return cache

        # Ask the LLM chain to turn the question into a DAX query.
        try:
            logger.info("Running PBI Query Tool with input: %s", tool_input)
            query = self.llm_chain.predict(
                tool_input=tool_input,
                tables=self.powerbi.get_table_names(),
                schemas=self.powerbi.get_schemas(),
                examples=self.examples,
                callbacks=run_manager.get_child() if run_manager else None,
            )
        except Exception as exc:  # pylint: disable=broad-except
            self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
            return self.session_cache[tool_input]
        if query == "I cannot answer this":
            self.session_cache[tool_input] = query
            return self.session_cache[tool_input]
        logger.info("PBI Query:\n%s", query)
        start_time = perf_counter()
        pbi_result = self.powerbi.run(command=query)
        end_time = perf_counter()
        logger.debug("PBI Result: %s", pbi_result)
        logger.debug("PBI Query duration: %0.6f", end_time - start_time)
        result, error = self._parse_output(pbi_result)
        if error is not None and "TokenExpired" in error:
            self.session_cache[tool_input] = (
                "Authentication token expired or invalid, please try to reauthenticate."
            )
            return self.session_cache[tool_input]

        # On error, retry with the failed query and error appended to the input,
        # up to max_iterations times.
        iterations = kwargs.get("iterations", 0)
        if error and iterations < self.max_iterations:
            return self._run(
                tool_input=RETRY_RESPONSE.format(
                    tool_input=tool_input, query=query, error=error
                ),
                run_manager=run_manager,
                iterations=iterations + 1,
            )

        self.session_cache[tool_input] = (
            result if result else BAD_REQUEST_RESPONSE.format(error=error)
        )
        return self.session_cache[tool_input]

    async def _arun(
        self,
        tool_input: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> str:
        """Execute the query, return the results or an error message."""
        if cache := self._check_cache(tool_input):
            logger.debug("Found cached result for %s: %s", tool_input, cache)
            return f"{cache}, from cache, you have already asked this question."
        try:
            logger.info("Running PBI Query Tool with input: %s", tool_input)
            query = await self.llm_chain.apredict(
                tool_input=tool_input,
                tables=self.powerbi.get_table_names(),
                schemas=self.powerbi.get_schemas(),
                examples=self.examples,
                callbacks=run_manager.get_child() if run_manager else None,
            )
        except Exception as exc:  # pylint: disable=broad-except
            self.session_cache[tool_input] = f"Error on call to LLM: {exc}"
            return self.session_cache[tool_input]

        if query == "I cannot answer this":
            self.session_cache[tool_input] = query
            return self.session_cache[tool_input]
        logger.info("PBI Query: %s", query)
        start_time = perf_counter()
        pbi_result = await self.powerbi.arun(command=query)
        end_time = perf_counter()
        logger.debug("PBI Result: %s", pbi_result)
        logger.debug("PBI Query duration: %0.6f", end_time - start_time)
        result, error = self._parse_output(pbi_result)
        if error is not None and ("TokenExpired" in error or "TokenError" in error):
            self.session_cache[tool_input] = (
                "Authentication token expired or invalid, please try to reauthenticate or check the scope of the credential."  # noqa: E501
            )
            return self.session_cache[tool_input]

        iterations = kwargs.get("iterations", 0)
        if error and iterations < self.max_iterations:
            return await self._arun(
                tool_input=RETRY_RESPONSE.format(
                    tool_input=tool_input, query=query, error=error
                ),
                run_manager=run_manager,
                iterations=iterations + 1,
            )

        self.session_cache[tool_input] = (
            result if result else BAD_REQUEST_RESPONSE.format(error=error)
        )
        return self.session_cache[tool_input]

    def _parse_output(
        self, pbi_result: Dict[str, Any]
    ) -> Tuple[Optional[str], Optional[Any]]:
        """Parse the query output into a (result, error) pair.

        On success the result is a markdown table and the error is None.
        """
        if "results" in pbi_result:
            rows = pbi_result["results"][0]["tables"][0]["rows"]
            if len(rows) == 0:
                logger.info("0 records in result, query was valid.")
                return (
                    None,
                    "0 rows returned, this might be correct, but please validate that all filter values were correct.",  # noqa: E501
                )
            result = json_to_md(rows)
            too_long, length = self._result_too_large(result)
            if too_long:
                return (
                    f"Result too large, please try to be more specific or use the `TOPN` function. The result is {length} tokens long, the limit is {self.output_token_limit} tokens.",  # noqa: E501
                    None,
                )
            return result, None

        if "error" in pbi_result:
            if (
                "pbi.error" in pbi_result["error"]
                and "details" in pbi_result["error"]["pbi.error"]
            ):
                return None, pbi_result["error"]["pbi.error"]["details"][0]["detail"]
            return None, pbi_result["error"]
        return None, pbi_result

    def _result_too_large(self, result: str) -> Tuple[bool, int]:
        """Check the result against the token limit.

        Returns (too_long, token_count). Without a tiktoken model name the
        result is never counted and (False, 0) is returned.
        """
        if self.tiktoken_model_name:
            tiktoken_ = _import_tiktoken()
            encoding = tiktoken_.encoding_for_model(self.tiktoken_model_name)
            length = len(encoding.encode(result))
            logger.info("Result length: %s", length)
            return length > self.output_token_limit, length
        return False, 0


class InfoPowerBITool(BaseTool):
    """Tool for getting metadata about a Power BI Dataset."""

    name: str = "schema_powerbi"
    description: str = """
    Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
    Be sure that the tables actually exist by calling list_tables_powerbi first!
    Example Input: "table1, table2, table3"
    """  # noqa: E501
    powerbi: PowerBIDataset = Field(exclude=True)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _run(
        self,
        tool_input: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list."""
        # Names are split on ", " exactly, matching the example input above.
        return self.powerbi.get_table_info(tool_input.split(", "))

    async def _arun(
        self,
        tool_input: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list."""
        return await self.powerbi.aget_table_info(tool_input.split(", "))


class ListPowerBITool(BaseTool):
    """Tool for getting table names."""

    name: str = "list_tables_powerbi"
    description: str = "Input is an empty string, output is a comma-separated list of tables in the dataset."  # noqa: E501 # pylint: disable=C0301
    powerbi: PowerBIDataset = Field(exclude=True)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _run(
        self,
        tool_input: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the names of the tables."""
        return ", ".join(self.powerbi.get_table_names())

    async def _arun(
        self,
        tool_input: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Get the names of the tables."""
        return ", ".join(self.powerbi.get_table_names())
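
The retry flow in `QueryPowerBITool._run` recurses with an `iterations` counter. The sketch below restates that flow as a loop so the bound is easier to see. It is not the code in this commit: `answer_with_retries`, `fake_generate`, and `fake_run` are hypothetical names, and the two callables stand in for the LLM chain and the Power BI client. Caching and the token-expiry check are left out.

```python
from typing import Callable, Optional, Tuple

RETRY_RESPONSE = (
    "{tool_input} DAX: {query} Error: {error}. Please supply a new DAX query."
)
BAD_REQUEST_RESPONSE = (
    "Error on this question, the error was {error}, "
    "you can try to rephrase the question."
)


def answer_with_retries(
    question: str,
    generate_query: Callable[[str], str],
    run_query: Callable[[str], Tuple[Optional[str], Optional[str]]],
    max_iterations: int = 5,
) -> str:
    """Ask for a query, run it, and on error feed the failure back in."""
    tool_input = question
    error: Optional[str] = None
    # One initial attempt plus up to max_iterations retries.
    for _ in range(max_iterations + 1):
        query = generate_query(tool_input)
        if query == "I cannot answer this":
            return query
        result, error = run_query(query)
        if result:
            return result
        # The failed query and its error become part of the next input.
        tool_input = RETRY_RESPONSE.format(
            tool_input=tool_input, query=query, error=error
        )
    return BAD_REQUEST_RESPONSE.format(error=error)


# Hypothetical stand-ins: the first query misspells the table name and the
# second, generated after the error is fed back, is correct.
def fake_generate(tool_input: str) -> str:
    if "Error:" in tool_input:
        return "EVALUATE ROW(\"Number of rows\", COUNTROWS(Sales))"
    return "EVALUATE ROW(\"Number of rows\", COUNTROWS(Sale))"


def fake_run(query: str) -> Tuple[Optional[str], Optional[str]]:
    if "COUNTROWS(Sales)" in query:
        return "| Number of rows |\n|---|\n| 42 |", None
    return None, "Cannot find table 'Sale'"


answer = answer_with_retries("How many rows are in Sales?", fake_generate, fake_run)
print(answer)
```

As in the tool itself, each retry wraps the previous input, so the text passed to the query generator grows with every failed attempt. That growth is one reason to keep `max_iterations` small.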