initial commit
This commit is contained in:
285
venv/Lib/site-packages/duckdb/polars_io.py
Normal file
285
venv/Lib/site-packages/duckdb/polars_io.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations # noqa: D100
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import json
|
||||
import typing
|
||||
from decimal import Decimal
|
||||
|
||||
import polars as pl
|
||||
from polars.io.plugins import register_io_source
|
||||
|
||||
import duckdb
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
import typing_extensions
|
||||
|
||||
_ExpressionTree: typing_extensions.TypeAlias = typing.Dict[str, typing.Union[str, int, "_ExpressionTree", typing.Any]] # noqa: UP006
|
||||
|
||||
|
||||
def _predicate_to_expression(predicate: pl.Expr) -> duckdb.Expression | None:
    """Convert a Polars predicate into a DuckDB expression usable as a filter.

    The predicate is serialized to its JSON expression tree, translated to SQL
    text by ``_pl_tree_to_sql`` and wrapped in ``duckdb.SQLExpression``.

    Parameters:
        predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5)

    Returns:
        duckdb.Expression: The DuckDB expression built from the generated SQL.
        None: If the predicate cannot be serialized or translated.

    Example:
        For ``pl.col("foo") > 5`` the generated SQL text is ``("foo" > 5)``;
        column names are always emitted as quoted identifiers.
    """
    try:
        # Serialization lives inside the ``try`` on purpose: pushdown is only an
        # optimization, so a predicate that cannot even be serialized must make
        # the caller fall back to filtering in Polars instead of failing the scan.
        tree = json.loads(predicate.meta.serialize(format="json"))
        # Convert the tree to SQL
        sql_filter = _pl_tree_to_sql(tree)
        return duckdb.SQLExpression(sql_filter)
    except Exception:
        # Deliberate best effort: any failure means "no pushdown possible".
        return None
|
||||
|
||||
|
||||
def _pl_operation_to_sql(op: str) -> str:
    """Map a Polars binary operation name to its SQL operator.

    Parameters:
        op (str): Operator name as it appears in the serialized Polars
            expression tree (e.g. "Eq", "LtEq", "And").

    Returns:
        str: The SQL operator text.

    Raises:
        NotImplementedError: If ``op`` has no SQL mapping here. The exception
            argument is the unsupported operator name.

    Example:
        >>> _pl_operation_to_sql("Eq")
        '='
    """
    sql_operators = {
        "Lt": "<",
        "LtEq": "<=",
        "Gt": ">",
        "GtEq": ">=",
        "Eq": "=",
        "Modulus": "%",
        "And": "AND",
        "Or": "OR",
    }
    try:
        return sql_operators[op]
    except KeyError:
        # The KeyError is an implementation detail of the lookup; hide it so the
        # traceback only shows the unsupported operator.
        raise NotImplementedError(op) from None
|
||||
|
||||
|
||||
def _escape_sql_identifier(identifier: str) -> str:
    """Return ``identifier`` as a double-quoted SQL identifier.

    Every double quote inside the name is doubled before the surrounding
    quotes are added.

    Example:
        >>> _escape_sql_identifier('column"name')
        '"column""name"'
    """
    doubled_quotes = '""'.join(identifier.split('"'))
    return '"' + doubled_quotes + '"'
|
||||
|
||||
|
||||
def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
    """Recursively convert a Polars expression tree (as JSON) to a SQL string.

    Parameters:
        tree (dict): JSON-deserialized expression tree from Polars. It must have
            exactly one key (the node type) mapping to that node's payload.

    Returns:
        str: SQL expression string

    Raises:
        NotImplementedError: If a node type, operator, function or scalar value
            has no exact SQL translation here.
        AssertionError: If a node payload does not have the expected shape.
        ValueError: If ``tree`` does not have exactly one key.

    Example:
        Input tree:
        {
            "BinaryExpr": {
                "left": { "Column": "foo" },
                "op": "Gt",
                "right": { "Literal": { "Int": 5 } }
            }
        }
        Output: '("foo" > 5)'
    """
    [node_type] = tree.keys()
    payload = tree[node_type]

    if node_type == "BinaryExpr":
        # Binary expressions: left OP right
        assert isinstance(payload, dict), f"A {node_type} should be a dict but got {type(payload)}"
        lhs, op, rhs = payload["left"], payload["op"], payload["right"]
        assert isinstance(lhs, dict), f"LHS of a {node_type} should be a dict but got {type(lhs)}"
        assert isinstance(op, str), f"The op of a {node_type} should be a str but got {type(op)}"
        assert isinstance(rhs, dict), f"RHS of a {node_type} should be a dict but got {type(rhs)}"
        return f"({_pl_tree_to_sql(lhs)} {_pl_operation_to_sql(op)} {_pl_tree_to_sql(rhs)})"

    if node_type == "Column":
        # A reference to a column name, quoted to handle special characters
        assert isinstance(payload, str), f"The col name of a {node_type} should be a str but got {type(payload)}"
        return _escape_sql_identifier(payload)

    if node_type in ("Literal", "Dyn"):
        # Wrapper nodes: the actual value is the nested tree
        assert isinstance(payload, dict), f"A {node_type} should be a dict but got {type(payload)}"
        return _pl_tree_to_sql(payload)

    if node_type == "Int":
        # Direct integer literals
        assert isinstance(payload, (int, str)), (
            f"The value of an Int should be an int or str but got {type(payload)}"
        )
        return str(payload)

    if node_type == "Function":
        # Handle boolean functions like IsNull, IsNotNull
        assert isinstance(payload, dict), f"A {node_type} should be a dict but got {type(payload)}"
        inputs = payload["input"]
        assert isinstance(inputs, list), f"A {node_type} should have a list of dicts as input but got {type(inputs)}"
        input_tree = inputs[0]
        assert isinstance(input_tree, dict), (
            f"A {node_type} should have a list of dicts as input but got {type(input_tree)}"
        )
        func_dict = payload["function"]
        assert isinstance(func_dict, dict), (
            f"A {node_type} should have a function dict as input but got {type(func_dict)}"
        )

        if "Boolean" in func_dict:
            func = func_dict["Boolean"]
            arg_sql = _pl_tree_to_sql(input_tree)

            if func == "IsNull":
                return f"({arg_sql} IS NULL)"
            if func == "IsNotNull":
                return f"({arg_sql} IS NOT NULL)"
            msg = f"Boolean function not supported: {func}"
            raise NotImplementedError(msg)

        msg = f"Unsupported function type: {func_dict}"
        raise NotImplementedError(msg)

    if node_type == "Scalar":
        assert isinstance(payload, dict), f"A {node_type} should be a dict but got {type(payload)}"
        return _pl_scalar_to_sql(payload)

    msg = f"Node type: {node_type} is not implemented. {payload}"
    raise NotImplementedError(msg)


def _pl_scalar_to_sql(scalar_tree: dict[str, typing.Any]) -> str:
    """Render the payload of a Polars ``Scalar`` node as a SQL literal.

    Two layouts are accepted: the old style with explicit ``dtype`` and
    ``value`` entries, and the new style where the dtype name is the single
    key of the dict.

    Raises:
        NotImplementedError: If the scalar type is unsupported, or the value
            cannot be written as an exact SQL literal (null / non-finite
            numbers, sub-microsecond temporal values).
    """
    if "dtype" in scalar_tree and "value" in scalar_tree:
        dtype = str(scalar_tree["dtype"])
        value = scalar_tree["value"]
    else:
        # New style: dtype is the single key in the dict
        dtype = next(iter(scalar_tree.keys()))
        value = scalar_tree
    assert isinstance(dtype, str), f"A Scalar should have a str dtype but got {type(dtype)}"
    assert isinstance(value, dict), f"A Scalar should have a dict value but got {type(value)}"

    # Decimal support: [unscaled, scale] or [unscaled, <middle entry>, scale]
    if dtype.startswith("{'Decimal'") or dtype == "Decimal":
        decimal_value = value["Decimal"]
        assert isinstance(decimal_value, list), (
            f"A {dtype} should be a two or three member list but got {type(decimal_value)}"
        )
        assert 2 <= len(decimal_value) <= 3, (
            f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list"
        )
        # Shift the exponent instead of dividing: Decimal division rounds to the
        # context precision (28 digits by default) and would corrupt wider values.
        sign, digits, exponent = Decimal(decimal_value[0]).as_tuple()
        assert isinstance(exponent, int), f"A {dtype} should be a finite number but got exponent {exponent!r}"
        literal = format(Decimal((sign, digits, exponent - decimal_value[-1])), "f")
        if "." in literal:
            # Drop trailing fractional zeros, matching what the division produced.
            literal = literal.rstrip("0").rstrip(".")
        return literal

    # Datetime as ticks since epoch; assumes the time unit, when serialized,
    # is the second list entry named as below (to confirm against Polars).
    if dtype.startswith("{'Datetime'") or dtype == "Datetime":
        dt_parts = value["Datetime"]
        assert isinstance(dt_parts, list), f"A {dtype} should be a list but got {type(dt_parts)}"
        ticks = dt_parts[0]
        unit = dt_parts[1] if len(dt_parts) > 1 else None
        if unit == "Nanoseconds":
            micros, sub_micro = divmod(ticks, 1_000)
            if sub_micro:
                # Truncating could change the outcome of a comparison.
                msg = f"Sub-microsecond Datetime literals are not supported: {ticks}"
                raise NotImplementedError(msg)
        elif unit == "Milliseconds":
            micros = ticks * 1_000
        else:
            # "Microseconds", or no/unknown unit: microseconds since epoch.
            micros = ticks
        # Epoch + timedelta is exact and works for pre-1970 values on every
        # platform, unlike a float round-trip through fromtimestamp().
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
        dt_timestamp = epoch + datetime.timedelta(microseconds=micros)
        return f"'{dt_timestamp!s}'::TIMESTAMP"

    # Match simple numeric/boolean types
    if dtype in (
        "Int8",
        "Int16",
        "Int32",
        "Int64",
        "UInt8",
        "UInt16",
        "UInt32",
        "UInt64",
        "Float32",
        "Float64",
        "Boolean",
    ):
        literal = str(value[dtype])
        if literal in ("None", "nan", "inf", "-inf"):
            # These spellings would be read by SQL as identifiers, not values.
            msg = f"Unsupported scalar type {dtype!s}, with value {value}"
            raise NotImplementedError(msg)
        return literal

    # Time type: nanoseconds since midnight
    if dtype == "Time":
        nanoseconds = value["Time"]
        assert isinstance(nanoseconds, int), f"A {dtype} should be an int but got {type(nanoseconds)}"
        whole_micros, sub_micro = divmod(nanoseconds, 1_000)
        if sub_micro:
            # datetime.time only holds microseconds; refuse to truncate silently.
            msg = f"Sub-microsecond Time literals are not supported: {nanoseconds}"
            raise NotImplementedError(msg)
        dt_time = (datetime.datetime.min + datetime.timedelta(microseconds=whole_micros)).time()
        return f"'{dt_time}'::TIME"

    # Date type: days since epoch
    if dtype == "Date":
        days_since_epoch = value["Date"]
        assert isinstance(days_since_epoch, (float, int)), (
            f"A {dtype} should be a number but got {type(days_since_epoch)}"
        )
        date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch)
        return f"'{date}'::DATE"

    # Binary type: list of byte values, emitted as \xNN escapes
    if dtype == "Binary":
        bin_value = value["Binary"]
        assert isinstance(bin_value, list), f"A {dtype} should be a list but got {type(bin_value)}"
        escaped = "".join(f"\\x{b:02x}" for b in bytes(bin_value))
        return f"'{escaped}'::BLOB"

    # String type
    if dtype == "String" or dtype == "StringOwned":
        # Some new formats may store directly under StringOwned
        string_val = value.get("StringOwned", value.get("String", None))
        # Let DuckDB render the constant so quoting is handled by the engine
        return str(duckdb.ConstantExpression(string_val))

    msg = f"Unsupported scalar type {dtype!s}, with value {value}"
    raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame:
    """Register a DuckDB relation as a Polars IO source.

    Parameters:
        relation: The DuckDB relation to read from.
        schema: The Polars schema passed on to ``register_io_source``.

    Returns:
        The object returned by ``register_io_source`` for the generator below.
    """

    def source_generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        # Build the query step by step: projection, then row limit, then the
        # filter (only when the predicate could be translated for DuckDB).
        query = relation
        if with_columns is not None:
            projection = ",".join(_escape_sql_identifier(name) for name in with_columns)
            query = query.project(projection)
        if n_rows is not None:
            query = query.limit(n_rows)

        pushed_filter = None
        if predicate is not None:
            # Pushdown is best effort; a failed translation leaves pushed_filter as None
            with contextlib.suppress(AssertionError, KeyError):
                pushed_filter = _predicate_to_expression(predicate)
        if pushed_filter is not None:
            query = query.filter(pushed_filter)

        reader = query.fetch_arrow_reader() if batch_size is None else query.fetch_arrow_reader(batch_size)

        # When a predicate exists but was not pushed down, apply it per batch in Polars
        filter_in_polars = predicate is not None and pushed_filter is None
        for arrow_batch in iter(reader.read_next_batch, None):
            frame = pl.from_arrow(arrow_batch)
            if filter_in_polars:
                yield frame.filter(predicate)  # type: ignore[arg-type,misc,union-attr,unused-ignore]
            else:
                yield frame  # type: ignore[misc,unused-ignore]

    return register_io_source(source_generator, schema=schema)
|
||||
Reference in New Issue
Block a user