initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from langchain_community.tools.databricks.tool import UCFunctionToolkit
__all__ = ["UCFunctionToolkit"]

View File

@@ -0,0 +1,254 @@
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from io import StringIO
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import FunctionInfo
from databricks.sdk.service.sql import StatementParameterListItem, StatementState
EXECUTE_FUNCTION_ARG_NAME = "__execution_args__"
DEFAULT_EXECUTE_FUNCTION_ARGS = {
"wait_timeout": "30s",
"row_limit": 100,
"byte_limit": 4096,
}
UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "UC_TOOL_CLIENT_EXECUTION_TIMEOUT"
DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT = "120"
_logger = logging.getLogger(__name__)
def is_scalar(function: "FunctionInfo") -> bool:
from databricks.sdk.service.catalog import ColumnTypeName
return function.data_type != ColumnTypeName.TABLE_TYPE
@dataclass
class ParameterizedStatement:
statement: str
parameters: List["StatementParameterListItem"]
@dataclass
class FunctionExecutionResult:
"""
Result of executing a function.
We always use a string to present the result value for AI model to consume.
"""
error: Optional[str] = None
format: Optional[Literal["SCALAR", "CSV"]] = None
value: Optional[str] = None
truncated: Optional[bool] = None
def to_json(self) -> str:
data = {k: v for (k, v) in self.__dict__.items() if v is not None}
return json.dumps(data)
def get_execute_function_sql_stmt(
function: "FunctionInfo", json_params: Dict[str, Any]
) -> ParameterizedStatement:
from databricks.sdk.service.catalog import ColumnTypeName
from databricks.sdk.service.sql import StatementParameterListItem
parts = []
output_params = []
if is_scalar(function):
# TODO: IDENTIFIER(:function) did not work
parts.append(f"SELECT {function.full_name}(")
else:
parts.append(f"SELECT * FROM {function.full_name}(")
if function.input_params is None or function.input_params.parameters is None:
assert not json_params, (
"Function has no parameters but parameters were provided."
)
else:
args = []
use_named_args = False
for p in function.input_params.parameters:
if p.name not in json_params:
if p.parameter_default is not None:
use_named_args = True
else:
raise ValueError(
f"Parameter {p.name} is required but not provided."
)
else:
arg_clause = ""
if use_named_args:
arg_clause += f"{p.name} => "
json_value = json_params[p.name]
if p.type_name in (
ColumnTypeName.ARRAY,
ColumnTypeName.MAP,
ColumnTypeName.STRUCT,
):
# Use from_json to restore values of complex types.
json_value_str = json.dumps(json_value)
# TODO: parametrize type
arg_clause += f"from_json(:{p.name}, '{p.type_text}')"
output_params.append(
StatementParameterListItem(name=p.name, value=json_value_str)
)
elif p.type_name == ColumnTypeName.BINARY:
# Use ubbase64 to restore binary values.
arg_clause += f"unbase64(:{p.name})"
output_params.append(
StatementParameterListItem(name=p.name, value=json_value)
)
else:
arg_clause += f":{p.name}"
output_params.append(
StatementParameterListItem(
name=p.name, value=json_value, type=p.type_text
)
)
args.append(arg_clause)
parts.append(",".join(args))
parts.append(")")
# TODO: check extra params in kwargs
statement = "".join(parts)
return ParameterizedStatement(statement=statement, parameters=output_params)
def execute_function(
ws: "WorkspaceClient",
warehouse_id: str,
function: "FunctionInfo",
parameters: Dict[str, Any],
) -> FunctionExecutionResult:
"""
Execute a function with the given arguments and return the result.
"""
try:
import pandas as pd
except ImportError as e:
raise ImportError(
"Could not import pandas python package. "
"Please install it with `pip install pandas`."
) from e
from databricks.sdk.service.sql import StatementState
if (
function.input_params
and function.input_params.parameters
and any(
p.name == EXECUTE_FUNCTION_ARG_NAME
for p in function.input_params.parameters
)
):
raise ValueError(
"Parameter name conflicts with the reserved argument name for executing "
f"functions: {EXECUTE_FUNCTION_ARG_NAME}. "
f"Please rename the parameter {EXECUTE_FUNCTION_ARG_NAME}."
)
# avoid modifying the original dict
execute_statement_args = {**DEFAULT_EXECUTE_FUNCTION_ARGS}
allowed_execute_statement_args = inspect.signature(
ws.statement_execution.execute_statement
).parameters
if not any(
p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
for p in allowed_execute_statement_args.values()
):
invalid_params = set()
passed_execute_statement_args = parameters.pop(EXECUTE_FUNCTION_ARG_NAME, {})
for k, v in passed_execute_statement_args.items():
if k in allowed_execute_statement_args:
execute_statement_args[k] = v
else:
invalid_params.add(k)
if invalid_params:
raise ValueError(
f"Invalid parameters for executing functions: {invalid_params}. "
f"Allowed parameters are: {allowed_execute_statement_args.keys()}."
)
# TODO: async so we can run functions in parallel
parametrized_statement = get_execute_function_sql_stmt(function, parameters)
response = ws.statement_execution.execute_statement(
statement=parametrized_statement.statement,
warehouse_id=warehouse_id,
parameters=parametrized_statement.parameters,
**execute_statement_args,
)
if response.status and job_pending(response.status.state) and response.statement_id:
statement_id = response.statement_id
wait_time = 0
retry_cnt = 0
client_execution_timeout = int(
os.environ.get(
UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT,
)
)
while wait_time < client_execution_timeout:
wait = min(2**retry_cnt, client_execution_timeout - wait_time)
_logger.debug(
f"Retrying {retry_cnt} time to get statement execution "
f"status after {wait} seconds."
)
time.sleep(wait)
response = ws.statement_execution.get_statement(statement_id)
if response.status is None or not job_pending(response.status.state):
break
wait_time += wait
retry_cnt += 1
if response.status and job_pending(response.status.state):
return FunctionExecutionResult(
error=f"Statement execution is still pending after {wait_time} "
"seconds. Please increase the wait_timeout argument for executing "
f"the function or increase {UC_TOOL_CLIENT_EXECUTION_TIMEOUT} "
"environment variable for increasing retrying time, default is "
f"{DEFAULT_UC_TOOL_CLIENT_EXECUTION_TIMEOUT} seconds."
)
assert response.status is not None, f"Statement execution failed: {response}"
if response.status.state != StatementState.SUCCEEDED:
error = response.status.error
assert error is not None, (
f"Statement execution failed but no error message was provided: {response}"
)
return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
manifest = response.manifest
assert manifest is not None
truncated = manifest.truncated
result = response.result
assert result is not None, (
"Statement execution succeeded but no result was provided."
)
data_array = result.data_array
if is_scalar(function):
value = None
if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
value = str(data_array[0][0])
return FunctionExecutionResult(
format="SCALAR", value=value, truncated=truncated
)
else:
schema = manifest.schema
assert schema is not None and schema.columns is not None, (
"Statement execution succeeded but no schema was provided."
)
columns = [c.name for c in schema.columns]
if data_array is None:
data_array = []
pdf = pd.DataFrame.from_records(data_array, columns=columns)
csv_buffer = StringIO()
pdf.to_csv(csv_buffer, index=False)
return FunctionExecutionResult(
format="CSV", value=csv_buffer.getvalue(), truncated=truncated
)
def job_pending(state: Optional["StatementState"]) -> bool:
from databricks.sdk.service.sql import StatementState
return state in (StatementState.PENDING, StatementState.RUNNING)

View File

@@ -0,0 +1,210 @@
import json
from datetime import date, datetime
from decimal import Decimal
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
from langchain_core._api import deprecated
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.tools.base import BaseToolkit
from pydantic import BaseModel, Field, create_model
from typing_extensions import Self
if TYPE_CHECKING:
from databricks.sdk.service.catalog import FunctionInfo
from pydantic import ConfigDict
from langchain_community.tools.databricks._execution import execute_function
def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
mapping = {
"long": int,
"binary": bytes,
"boolean": bool,
"date": date,
"double": float,
"float": float,
"integer": int,
"short": int,
"string": str,
"timestamp": datetime,
"timestamp_ntz": datetime,
"byte": int,
}
if isinstance(uc_type_json, str):
if uc_type_json in mapping:
return mapping[uc_type_json]
else:
if uc_type_json.startswith("decimal"):
return Decimal
elif uc_type_json == "void" or uc_type_json.startswith("interval"):
raise TypeError(f"Type {uc_type_json} is not supported.")
else:
raise TypeError(
f"Unknown type {uc_type_json}. Try upgrading this package."
)
else:
assert isinstance(uc_type_json, dict)
tpe = uc_type_json["type"]
if tpe == "array":
element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
if uc_type_json["containsNull"]:
element_type = Optional[element_type] # type: ignore[assignment]
return List[element_type] # type: ignore[valid-type]
elif tpe == "map":
key_type = uc_type_json["keyType"]
assert key_type == "string", TypeError(
f"Only support STRING key type for MAP but got {key_type}."
)
value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
if uc_type_json["valueContainsNull"]:
value_type: Type = Optional[value_type] # type: ignore[no-redef]
return Dict[str, value_type] # type: ignore[valid-type]
elif tpe == "struct":
fields = {}
for field in uc_type_json["fields"]:
field_type = _uc_type_to_pydantic_type(field["type"])
if field.get("nullable"):
field_type = Optional[field_type] # type: ignore[assignment]
comment = (
uc_type_json["metadata"].get("comment")
if "metadata" in uc_type_json
else None
)
fields[field["name"]] = (field_type, Field(..., description=comment))
uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
return create_model(f"Struct_{type_hash}", **fields) # type: ignore[call-overload]
else:
raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
if function.input_params is None:
return BaseModel
params = function.input_params.parameters
assert params is not None
fields = {}
for p in params:
assert p.type_json is not None
type_json = json.loads(p.type_json)["type"]
pydantic_type = _uc_type_to_pydantic_type(type_json)
description = p.comment
default: Any = ...
if p.parameter_default:
pydantic_type = Optional[pydantic_type] # type: ignore[assignment]
default = None
# TODO: Convert default value string to the correct type.
# We might need to use statement execution API
# to get the JSON representation of the value.
default_description = f"(Default: {p.parameter_default})"
if description:
description += f" {default_description}"
else:
description = default_description
fields[p.name] = (
pydantic_type,
Field(default=default, description=description),
)
return create_model( # type: ignore[call-overload]
f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
**fields,
)
def _get_tool_name(function: "FunctionInfo") -> str:
tool_name = f"{function.catalog_name}__{function.schema_name}__{function.name}"[
-64:
]
return tool_name
def _get_default_workspace_client() -> Any:
try:
from databricks.sdk import WorkspaceClient
except ImportError as e:
raise ImportError(
"Could not import databricks-sdk python package. "
"Please install it with `pip install databricks-sdk`."
) from e
return WorkspaceClient()
@deprecated(
since="0.3.18",
removal="1.0",
alternative_import="databricks_langchain.uc_ai.UCFunctionToolkit",
)
class UCFunctionToolkit(BaseToolkit):
warehouse_id: str = Field(
description="The ID of a Databricks SQL Warehouse to execute functions."
)
workspace_client: Any = Field(
default_factory=_get_default_workspace_client,
description="Databricks workspace client.",
)
tools: Dict[str, BaseTool] = Field(default_factory=dict)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def include(self, *function_names: str, **kwargs: Any) -> Self:
"""
Includes UC functions to the toolkit.
Args:
functions: A list of UC function names in the format
"catalog_name.schema_name.function_name" or
"catalog_name.schema_name.*".
If the function name ends with ".*",
all functions in the schema will be added.
kwargs: Extra arguments to pass to StructuredTool, e.g., `return_direct`.
"""
for name in function_names:
if name.endswith(".*"):
catalog_name, schema_name = name[:-2].split(".")
# TODO: handle pagination, warn and truncate if too many
functions = self.workspace_client.functions.list(
catalog_name=catalog_name, schema_name=schema_name
)
for f in functions:
assert f.full_name is not None
self.include(f.full_name, **kwargs)
else:
if name not in self.tools:
self.tools[name] = self._make_tool(name, **kwargs)
return self
def _make_tool(self, function_name: str, **kwargs: Any) -> BaseTool:
function = self.workspace_client.functions.get(function_name)
name = _get_tool_name(function)
description = function.comment or ""
args_schema = _generate_args_schema(function)
def func(*args: Any, **kwargs: Any) -> str:
# TODO: We expect all named args and ignore args.
# Non-empty args show up when the function has no parameters.
args_json = json.loads(json.dumps(kwargs, default=str))
result = execute_function(
ws=self.workspace_client,
warehouse_id=self.warehouse_id,
function=function,
parameters=args_json,
)
return result.to_json()
return StructuredTool(
name=name,
description=description,
args_schema=args_schema,
func=func,
**kwargs,
)
def get_tools(self) -> List[BaseTool]:
return list(self.tools.values())