initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,285 @@
from __future__ import annotations # noqa: D100
import contextlib
import datetime
import json
import typing
from decimal import Decimal
import polars as pl
from polars.io.plugins import register_io_source
import duckdb
if typing.TYPE_CHECKING:
from collections.abc import Iterator
import typing_extensions
_ExpressionTree: typing_extensions.TypeAlias = typing.Dict[str, typing.Union[str, int, "_ExpressionTree", typing.Any]] # noqa: UP006
def _predicate_to_expression(predicate: pl.Expr) -> duckdb.Expression | None:
"""Convert a Polars predicate expression to a DuckDB-compatible SQL expression.
Parameters:
predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5)
Returns:
SQLExpression: A DuckDB SQL expression string equivalent.
None: If conversion fails.
Example:
>>> _predicate_to_expression(pl.col("foo") > 5)
SQLExpression("(foo > 5)")
"""
# Serialize the Polars expression tree to JSON
tree = json.loads(predicate.meta.serialize(format="json"))
try:
# Convert the tree to SQL
sql_filter = _pl_tree_to_sql(tree)
return duckdb.SQLExpression(sql_filter)
except Exception:
# If the conversion fails, we return None
return None
def _pl_operation_to_sql(op: str) -> str:
"""Map Polars binary operation strings to SQL equivalents.
Example:
>>> _pl_operation_to_sql("Eq")
'='
"""
try:
return {
"Lt": "<",
"LtEq": "<=",
"Gt": ">",
"GtEq": ">=",
"Eq": "=",
"Modulus": "%",
"And": "AND",
"Or": "OR",
}[op]
except KeyError:
raise NotImplementedError(op) # noqa: B904
def _escape_sql_identifier(identifier: str) -> str:
"""Escape SQL identifiers by doubling any double quotes and wrapping in double quotes.
Example:
>>> _escape_sql_identifier('column"name')
'"column""name"'
"""
escaped = identifier.replace('"', '""')
return f'"{escaped}"'
def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
"""Recursively convert a Polars expression tree (as JSON) to a SQL string.
Parameters:
tree (dict): JSON-deserialized expression tree from Polars
Returns:
str: SQL expression string
Example:
Input tree:
{
"BinaryExpr": {
"left": { "Column": "foo" },
"op": "Gt",
"right": { "Literal": { "Int": 5 } }
}
}
Output: "(foo > 5)"
"""
[node_type] = tree.keys()
if node_type == "BinaryExpr":
# Binary expressions: left OP right
bin_expr_tree = tree[node_type]
assert isinstance(bin_expr_tree, dict), f"A {node_type} should be a dict but got {type(bin_expr_tree)}"
lhs, op, rhs = bin_expr_tree["left"], bin_expr_tree["op"], bin_expr_tree["right"]
assert isinstance(lhs, dict), f"LHS of a {node_type} should be a dict but got {type(lhs)}"
assert isinstance(op, str), f"The op of a {node_type} should be a str but got {type(op)}"
assert isinstance(rhs, dict), f"RHS of a {node_type} should be a dict but got {type(rhs)}"
return f"({_pl_tree_to_sql(lhs)} {_pl_operation_to_sql(op)} {_pl_tree_to_sql(rhs)})"
if node_type == "Column":
# A reference to a column name
# Wrap in quotes to handle special characters
col_name = tree[node_type]
assert isinstance(col_name, str), f"The col name of a {node_type} should be a str but got {type(col_name)}"
return _escape_sql_identifier(col_name)
if node_type in ("Literal", "Dyn"):
# Recursively process dynamic or literal values
val_tree = tree[node_type]
assert isinstance(val_tree, dict), f"A {node_type} should be a dict but got {type(val_tree)}"
return _pl_tree_to_sql(val_tree)
if node_type == "Int":
# Direct integer literals
int_literal = tree[node_type]
assert isinstance(int_literal, (int, str)), (
f"The value of an Int should be an int or str but got {type(int_literal)}"
)
return str(int_literal)
if node_type == "Function":
# Handle boolean functions like IsNull, IsNotNull
func_tree = tree[node_type]
assert isinstance(func_tree, dict), f"A {node_type} should be a dict but got {type(func_tree)}"
inputs = func_tree["input"]
assert isinstance(inputs, list), f"A {node_type} should have a list of dicts as input but got {type(inputs)}"
input_tree = inputs[0]
assert isinstance(input_tree, dict), (
f"A {node_type} should have a list of dicts as input but got {type(input_tree)}"
)
func_dict = func_tree["function"]
assert isinstance(func_dict, dict), (
f"A {node_type} should have a function dict as input but got {type(func_dict)}"
)
if "Boolean" in func_dict:
func = func_dict["Boolean"]
arg_sql = _pl_tree_to_sql(inputs[0])
if func == "IsNull":
return f"({arg_sql} IS NULL)"
if func == "IsNotNull":
return f"({arg_sql} IS NOT NULL)"
msg = f"Boolean function not supported: {func}"
raise NotImplementedError(msg)
msg = f"Unsupported function type: {func_dict}"
raise NotImplementedError(msg)
if node_type == "Scalar":
# Detect format: old style (dtype/value) or new style (direct type key)
scalar_tree = tree[node_type]
assert isinstance(scalar_tree, dict), f"A {node_type} should be a dict but got {type(scalar_tree)}"
if "dtype" in scalar_tree and "value" in scalar_tree:
dtype = str(scalar_tree["dtype"])
value = scalar_tree["value"]
else:
# New style: dtype is the single key in the dict
dtype = next(iter(scalar_tree.keys()))
value = scalar_tree
assert isinstance(dtype, str), f"A {node_type} should have a str dtype but got {type(dtype)}"
assert isinstance(value, dict), f"A {node_type} should have a dict value but got {type(value)}"
# Decimal support
if dtype.startswith("{'Decimal'") or dtype == "Decimal":
decimal_value = value["Decimal"]
assert isinstance(decimal_value, list), (
f"A {dtype} should be a two or three member list but got {type(decimal_value)}"
)
assert 2 <= len(decimal_value) <= 3, (
f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list"
)
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1]))
# Datetime with microseconds since epoch
if dtype.startswith("{'Datetime'") or dtype == "Datetime":
micros = value["Datetime"]
assert isinstance(micros, list), f"A {dtype} should be a one member list but got {type(micros)}"
dt_timestamp = datetime.datetime.fromtimestamp(micros[0] / 1_000_000, tz=datetime.timezone.utc)
return f"'{dt_timestamp!s}'::TIMESTAMP"
# Match simple numeric/boolean types
if dtype in (
"Int8",
"Int16",
"Int32",
"Int64",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"Float32",
"Float64",
"Boolean",
):
return str(value[dtype])
# Time type
if dtype == "Time":
nanoseconds = value["Time"]
assert isinstance(nanoseconds, int), f"A {dtype} should be an int but got {type(nanoseconds)}"
seconds = nanoseconds // 1_000_000_000
microseconds = (nanoseconds % 1_000_000_000) // 1_000
dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time()
return f"'{dt_time}'::TIME"
# Date type
if dtype == "Date":
days_since_epoch = value["Date"]
assert isinstance(days_since_epoch, (float, int)), (
f"A {dtype} should be a number but got {type(days_since_epoch)}"
)
date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch)
return f"'{date}'::DATE"
# Binary type
if dtype == "Binary":
bin_value = value["Binary"]
assert isinstance(bin_value, list), f"A {dtype} should be a list but got {type(bin_value)}"
binary_data = bytes(bin_value)
escaped = "".join(f"\\x{b:02x}" for b in binary_data)
return f"'{escaped}'::BLOB"
# String type
if dtype == "String" or dtype == "StringOwned":
# Some new formats may store directly under StringOwned
string_val = value.get("StringOwned", value.get("String", None))
# the string must be a string constant
return str(duckdb.ConstantExpression(string_val))
msg = f"Unsupported scalar type {dtype!s}, with value {value}"
raise NotImplementedError(msg)
msg = f"Node type: {node_type} is not implemented. {tree[node_type]}"
raise NotImplementedError(msg)
def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame:
"""A polars IO plugin for DuckDB."""
def source_generator(
with_columns: list[str] | None,
predicate: pl.Expr | None,
n_rows: int | None,
batch_size: int | None,
) -> Iterator[pl.DataFrame]:
duck_predicate = None
relation_final = relation
if with_columns is not None:
cols = ",".join(map(_escape_sql_identifier, with_columns))
relation_final = relation_final.project(cols)
if n_rows is not None:
relation_final = relation_final.limit(n_rows)
if predicate is not None:
# We have a predicate, if possible, we push it down to DuckDB
with contextlib.suppress(AssertionError, KeyError):
duck_predicate = _predicate_to_expression(predicate)
# Try to pushdown filter, if one exists
if duck_predicate is not None:
relation_final = relation_final.filter(duck_predicate)
if batch_size is None:
results = relation_final.fetch_arrow_reader()
else:
results = relation_final.fetch_arrow_reader(batch_size)
for record_batch in iter(results.read_next_batch, None):
if predicate is not None and duck_predicate is None:
# We have a predicate, but did not manage to push it down, we fallback here
yield pl.from_arrow(record_batch).filter(predicate) # type: ignore[arg-type,misc,unused-ignore]
else:
yield pl.from_arrow(record_batch) # type: ignore[misc,unused-ignore]
return register_io_source(source_generator, schema=schema)