initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
from .catalog import Catalog # noqa: D104
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameWriter
from .session import SparkSession
__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"]

View File

@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
try:
from typing import Literal, Protocol
except ImportError:
from typing_extensions import Literal, Protocol
import datetime
import decimal
from .._typing import PrimitiveType
from . import types
from .column import Column
ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
DecimalLiteral = decimal.Decimal
DateTimeLiteral = Union[datetime.datetime, datetime.date]
LiteralType = PrimitiveType
AtomicDataTypeOrString = Union[types.AtomicType, str]
DataTypeOrString = Union[types.DataType, str]
OptionalPrimitiveType = Optional[PrimitiveType]
AtomicValue = TypeVar(
"AtomicValue",
datetime.datetime,
datetime.date,
decimal.Decimal,
bool,
str,
int,
float,
)
RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row)
SQLBatchedUDFType = Literal[100]
class SupportsOpen(Protocol):
def open(self, partition_id: int, epoch_id: int) -> bool: ...
class SupportsProcess(Protocol):
def process(self, row: types.Row) -> None: ...
class SupportsClose(Protocol):
def close(self, error: Exception) -> None: ...
class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
evalType: int
deterministic: bool
@property
def returnType(self) -> types.DataType: ...
def __call__(self, *args: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> "UserDefinedFunctionLike": ...

View File

@@ -0,0 +1,79 @@
from typing import NamedTuple, Optional, Union # noqa: D100
from .session import SparkSession
class Database(NamedTuple): # noqa: D101
name: str
description: Optional[str]
locationUri: str
class Table(NamedTuple): # noqa: D101
name: str
database: Optional[str]
description: Optional[str]
tableType: str
isTemporary: bool
class Column(NamedTuple): # noqa: D101
name: str
description: Optional[str]
dataType: str
nullable: bool
isPartition: bool
isBucket: bool
class Function(NamedTuple): # noqa: D101
name: str
description: Optional[str]
className: str
isTemporary: bool
class Catalog: # noqa: D101
def __init__(self, session: SparkSession) -> None: # noqa: D107
self._session = session
def listDatabases(self) -> list[Database]: # noqa: D102
res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
def transform_to_database(x: list[str]) -> Database:
return Database(name=x[0], description=None, locationUri="")
databases = [transform_to_database(x) for x in res]
return databases
def listTables(self) -> list[Table]: # noqa: D102
res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()
def transform_to_table(x: list[str]) -> Table:
return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
tables = [transform_to_table(x) for x in res]
return tables
def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102
query = f"""
select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}'
"""
if dbName:
query += f" and database_name = '{dbName}'"
res = self._session.conn.sql(query).fetchall()
def transform_to_column(x: list[Union[str, bool]]) -> Column:
return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
columns = [transform_to_column(x) for x in res]
return columns
def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102
raise NotImplementedError
def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102
raise NotImplementedError
__all__ = ["Catalog", "Column", "Database", "Function", "Table"]

View File

@@ -0,0 +1,361 @@
from collections.abc import Iterable # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from ..exception import ContributionsAcceptedError
from .types import DataType
if TYPE_CHECKING:
from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType
from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression
from duckdb.sqltypes import DuckDBPyType
__all__ = ["Column"]
def _get_expr(x: Union["Column", str]) -> Expression:
return x.expr if isinstance(x, Column) else ConstantExpression(x)
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
def _(self: "Column") -> "Column":
njc = getattr(self.expr, name)()
return Column(njc)
_.__doc__ = doc
return _
def _unary_op(
name: str,
doc: str = "unary operator",
) -> Callable[["Column"], "Column"]:
"""Create a method for given unary operator."""
def _(self: "Column") -> "Column":
# Call the function identified by 'name' on the internal Expression object
expr = getattr(self.expr, name)()
return Column(expr)
_.__doc__ = doc
return _
def _bin_op(
name: str,
doc: str = "binary operator",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a method for given binary operator."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
jc = _get_expr(other)
njc = getattr(self.expr, name)(jc)
return Column(njc)
_.__doc__ = doc
return _
def _bin_func(
name: str,
doc: str = "binary function",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a function expression for the given binary function."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
other = _get_expr(other)
func = FunctionExpression(name, self.expr, other)
return Column(func)
_.__doc__ = doc
return _
class Column:
"""A column in a DataFrame.
:class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
df.colName
df["colName"]
# 2. Create from an expression
df.colName + 1
1 / df.colName
.. versionadded:: 1.3.0
"""
def __init__(self, expr: Expression) -> None: # noqa: D107
self.expr = expr
# arithmetic operators
def __neg__(self) -> "Column": # noqa: D105
return Column(-self.expr)
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op("__and__")
__or__ = _bin_op("__or__")
__invert__ = _func_op("__invert__")
__rand__ = _bin_op("__rand__")
__ror__ = _bin_op("__ror__")
__add__ = _bin_op("__add__")
__sub__ = _bin_op("__sub__")
__mul__ = _bin_op("__mul__")
__div__ = _bin_op("__div__")
__truediv__ = _bin_op("__truediv__")
__mod__ = _bin_op("__mod__")
__pow__ = _bin_op("__pow__")
__radd__ = _bin_op("__radd__")
__rsub__ = _bin_op("__rsub__")
__rmul__ = _bin_op("__rmul__")
__rdiv__ = _bin_op("__rdiv__")
__rtruediv__ = _bin_op("__rtruediv__")
__rmod__ = _bin_op("__rmod__")
__rpow__ = _bin_op("__rpow__")
def __getitem__(self, k: Any) -> "Column": # noqa: ANN401
"""An expression that gets an item at position ``ordinal`` out of a list,
or gets an item by key out of a dict.
.. versionadded:: 1.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
k
a literal value, or a slice object without step.
Returns:
-------
:class:`Column`
Column representing the item got by key out of a dict, or substrings sliced by
the given slice object.
Examples:
--------
>>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
>>> df.select(df.l[slice(1, 3)], df.d["key"]).show()
+------------------+------+
|substring(l, 1, 3)|d[key]|
+------------------+------+
| abc| value|
+------------------+------+
""" # noqa: D205
if isinstance(k, slice):
raise ContributionsAcceptedError
# if k.step is not None:
# raise ValueError("Using a slice with a step value is not supported")
# return self.substr(k.start, k.stop)
else:
# TODO: this is super hacky # noqa: TD002, TD003
expr_str = str(self.expr) + "." + str(k)
return Column(ColumnExpression(expr_str))
def __getattr__(self, item: Any) -> "Column": # noqa: ANN401
"""An expression that gets an item at position ``ordinal`` out of a list,
or gets an item by key out of a dict.
Parameters
----------
item
a literal value.
Returns:
-------
:class:`Column`
Column representing the item got by key out of a dict.
Examples:
--------
>>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
>>> df.select(df.d.key).show()
+------+
|d[key]|
+------+
| value|
+------+
""" # noqa: D205
if item.startswith("__"):
msg = "Can not access __ (dunder) method"
raise AttributeError(msg)
return self[item]
def alias(self, alias: str) -> "Column": # noqa: D102
return Column(self.expr.alias(alias))
def when(self, condition: "Column", value: Union["Column", str]) -> "Column": # noqa: D102
if not isinstance(condition, Column):
msg = "condition should be a Column"
raise TypeError(msg)
v = _get_expr(value)
expr = self.expr.when(condition.expr, v)
return Column(expr)
def otherwise(self, value: Union["Column", str]) -> "Column": # noqa: D102
v = _get_expr(value)
expr = self.expr.otherwise(v)
return Column(expr)
def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102
internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type
return Column(self.expr.cast(internal_type))
def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column": # noqa: D102
if len(cols) == 1 and isinstance(cols[0], (list, set)):
# Only one argument supplied, it's a list
cols = cast("tuple", cols[0])
cols = cast(
"tuple",
[_get_expr(c) for c in cols],
)
return Column(self.expr.isin(*cols))
# logistic operators
def __eq__( # type: ignore[override]
self,
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
"""Binary function."""
return Column(self.expr == (_get_expr(other)))
def __ne__( # type: ignore[override]
self,
other: object,
) -> "Column":
"""Binary function."""
return Column(self.expr != (_get_expr(other)))
__lt__ = _bin_op("__lt__")
__le__ = _bin_op("__le__")
__ge__ = _bin_op("__ge__")
__gt__ = _bin_op("__gt__")
# String interrogation methods
contains = _bin_func("contains")
rlike = _bin_func("regexp_matches")
like = _bin_func("~~")
ilike = _bin_func("~~*")
startswith = _bin_func("starts_with")
endswith = _bin_func("suffix")
# order
_asc_doc = """
Returns a sort expression based on the ascending order of the column.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc()).collect()
[Row(name='Alice'), Row(name='Tom')]
"""
_asc_nulls_first_doc = """
Returns a sort expression based on ascending order of the column, and null values
return before non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
[Row(name=None), Row(name='Alice'), Row(name='Tom')]
"""
_asc_nulls_last_doc = """
Returns a sort expression based on ascending order of the column, and null values
appear after non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
[Row(name='Alice'), Row(name='Tom'), Row(name=None)]
"""
_desc_doc = """
Returns a sort expression based on the descending order of the column.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc()).collect()
[Row(name='Tom'), Row(name='Alice')]
"""
_desc_nulls_first_doc = """
Returns a sort expression based on the descending order of the column, and null values
appear before non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
[Row(name=None), Row(name='Tom'), Row(name='Alice')]
"""
_desc_nulls_last_doc = """
Returns a sort expression based on the descending order of the column, and null values
appear after non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
[Row(name='Tom'), Row(name='Alice'), Row(name=None)]
"""
asc = _unary_op("asc", _asc_doc)
desc = _unary_op("desc", _desc_doc)
nulls_first = _unary_op("nulls_first")
nulls_last = _unary_op("nulls_last")
def asc_nulls_first(self) -> "Column": # noqa: D102
return self.asc().nulls_first()
def asc_nulls_last(self) -> "Column": # noqa: D102
return self.asc().nulls_last()
def desc_nulls_first(self) -> "Column": # noqa: D102
return self.desc().nulls_first()
def desc_nulls_last(self) -> "Column": # noqa: D102
return self.desc().nulls_last()
def isNull(self) -> "Column": # noqa: D102
return Column(self.expr.isnull())
def isNotNull(self) -> "Column": # noqa: D102
return Column(self.expr.isnotnull())

View File

@@ -0,0 +1,24 @@
from typing import Optional, Union # noqa: D100
from duckdb import DuckDBPyConnection
from duckdb.experimental.spark._globals import _NoValue, _NoValueType
class RuntimeConfig: # noqa: D101
def __init__(self, connection: DuckDBPyConnection) -> None: # noqa: D107
self._connection = connection
def set(self, key: str, value: str) -> None: # noqa: D102
raise NotImplementedError
def isModifiable(self, key: str) -> bool: # noqa: D102
raise NotImplementedError
def unset(self, key: str) -> None: # noqa: D102
raise NotImplementedError
def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102
raise NotImplementedError
__all__ = ["RuntimeConfig"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,424 @@
# # noqa: D100
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Callable, Union, overload
from ..exception import ContributionsAcceptedError
from .column import Column
from .dataframe import DataFrame
from .functions import _to_column_expr
from .types import NumericType
# Only import symbols needed for type checking if something is type checking
if TYPE_CHECKING:
from ._typing import ColumnOrName
from .session import SparkSession
__all__ = ["GroupedData", "Grouping"]
def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame:
expressions = ",".join(list(cols))
group_by = str(self._grouping) if self._grouping else ""
projections = self._grouping.get_columns()
jdf = self._df.relation.apply(
function_name=name, # aggregate function
function_aggr=expressions, # inputs to aggregate
group_expr=group_by, # groups
projected_columns=projections, # projections
)
return DataFrame(jdf, self.session)
def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
def _api(self: "GroupedData", *cols: str) -> DataFrame:
name = f.__name__
return _api_internal(self, name, *cols)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
class Grouping: # noqa: D101
def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: # noqa: D107
self._type = ""
self._cols = [_to_column_expr(x) for x in cols]
if "special" in kwargs:
special = kwargs["special"]
accepted_special = ["cube", "rollup"]
assert special in accepted_special
self._type = special
def get_columns(self) -> str: # noqa: D102
columns = ",".join([str(x) for x in self._cols])
return columns
def __str__(self) -> str: # noqa: D105
columns = self.get_columns()
if self._type:
return self._type + "(" + columns + ")"
return columns
class GroupedData:
"""A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
""" # noqa: D205
def __init__(self, grouping: Grouping, df: DataFrame) -> None: # noqa: D107
self._grouping = grouping
self._df = df
self.session: SparkSession = df.session
def __repr__(self) -> str: # noqa: D105
return str(self._df)
def count(self) -> DataFrame:
"""Counts the number of records for each group.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
... )
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
Group-by name, and count each group.
>>> df.groupBy(df.name).count().sort("name").show()
+-----+-----+
| name|count|
+-----+-----+
|Alice| 2|
| Bob| 2|
+-----+-----+
"""
return _api_internal(self, "count").withColumnRenamed("count_star()", "count")
@df_varargs_api
def mean(self, *cols: str) -> DataFrame:
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
"""
def avg(self, *cols: str) -> DataFrame:
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the mean of the age in each group.
>>> df.groupBy("name").avg("age").sort("name").show()
+-----+--------+
| name|avg(age)|
+-----+--------+
|Alice| 2.5|
| Bob| 7.5|
+-----+--------+
Calculate the mean of the age and height in all data.
>>> df.groupBy().avg("age", "height").show()
+--------+-----------+
|avg(age)|avg(height)|
+--------+-----------+
| 5.0| 110.0|
+--------+-----------+
"""
columns = list(cols)
if len(columns) == 0:
schema = self._df.schema
# Take only the numeric types of the relation
columns: list[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)]
return _api_internal(self, "avg", *columns)
@df_varargs_api
def max(self, *cols: str) -> DataFrame:
"""Computes the max value for each numeric columns for each group.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the max of the age in each group.
>>> df.groupBy("name").max("age").sort("name").show()
+-----+--------+
| name|max(age)|
+-----+--------+
|Alice| 3|
| Bob| 10|
+-----+--------+
Calculate the max of the age and height in all data.
>>> df.groupBy().max("age", "height").show()
+--------+-----------+
|max(age)|max(height)|
+--------+-----------+
| 10| 140|
+--------+-----------+
"""
@df_varargs_api
def min(self, *cols: str) -> DataFrame:
"""Computes the min value for each numeric column for each group.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the min of the age in each group.
>>> df.groupBy("name").min("age").sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
Calculate the min of the age and height in all data.
>>> df.groupBy().min("age", "height").show()
+--------+-----------+
|min(age)|min(height)|
+--------+-----------+
| 2| 80|
+--------+-----------+
"""
@df_varargs_api
def sum(self, *cols: str) -> DataFrame:
"""Computes the sum for each numeric columns for each group.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the sum of the age in each group.
>>> df.groupBy("name").sum("age").sort("name").show()
+-----+--------+
| name|sum(age)|
+-----+--------+
|Alice| 5|
| Bob| 15|
+-----+--------+
Calculate the sum of the age and height in all data.
>>> df.groupBy().sum("age", "height").show()
+--------+-----------+
|sum(age)|sum(height)|
+--------+-----------+
| 20| 440|
+--------+-----------+
"""
@overload
def agg(self, *exprs: Column) -> DataFrame: ...
@overload
def agg(self, __exprs: dict[str, str]) -> DataFrame: ... # noqa: PYI063
def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame:
"""Compute aggregates and returns the result as a :class:`DataFrame`.
The available aggregate functions can be:
1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`
2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`
.. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
a full shuffle is required. Also, all the data of a group will be loaded into
memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
.. seealso:: :func:`pyspark.sql.functions.pandas_udf`
If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
is the column to perform aggregation on, and the value is the aggregate function.
Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
.. versionadded:: 1.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
exprs : dict
a dict mapping from column name (string) to aggregate functions (string),
or a list of :class:`Column`.
Notes:
-----
Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
in a single call to this function.
Examples:
--------
>>> from pyspark.sql import functions as F
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
... )
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
Group-by name, and count each group.
>>> df.groupBy(df.name)
GroupedData[grouping...: [name...], value: [age: bigint, name: string], type: GroupBy]
>>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
+-----+--------+
| name|count(1)|
+-----+--------+
|Alice| 2|
| Bob| 2|
+-----+--------+
Group-by name, and calculate the minimum age.
>>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
Same as above but uses pandas UDF.
>>> @pandas_udf("int", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def min_udf(v):
... return v.min()
>>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP
+-----+------------+
| name|min_udf(age)|
+-----+------------+
|Alice| 2|
| Bob| 5|
+-----+------------+
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
raise ContributionsAcceptedError
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
expressions = list(self._grouping._cols)
expressions.extend([x.expr for x in exprs])
group_by = str(self._grouping)
rel = self._df.relation.select(*expressions, groups=group_by)
return DataFrame(rel, self.session)
# TODO: add 'pivot' # noqa: TD002, TD003

View File

@@ -0,0 +1,435 @@
from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100
from ..errors import PySparkNotImplementedError, PySparkTypeError
from ..exception import ContributionsAcceptedError
from .types import StructType
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
if TYPE_CHECKING:
from duckdb.experimental.spark.sql.dataframe import DataFrame
from duckdb.experimental.spark.sql.session import SparkSession
class DataFrameWriter: # noqa: D101
def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107
self.dataframe = dataframe
def saveAsTable(self, table_name: str) -> None: # noqa: D102
relation = self.dataframe.relation
relation.create(table_name)
def parquet( # noqa: D102
self,
path: str,
mode: Optional[str] = None,
partitionBy: Union[str, list[str], None] = None,
compression: Optional[str] = None,
) -> None:
relation = self.dataframe.relation
if mode:
raise NotImplementedError
if partitionBy:
raise NotImplementedError
relation.write_parquet(path, compression=compression)
def csv( # noqa: D102
self,
path: str,
mode: Optional[str] = None,
compression: Optional[str] = None,
sep: Optional[str] = None,
quote: Optional[str] = None,
escape: Optional[str] = None,
header: Optional[Union[bool, str]] = None,
nullValue: Optional[str] = None,
escapeQuotes: Optional[Union[bool, str]] = None,
quoteAll: Optional[Union[bool, str]] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
charToEscapeQuoteEscaping: Optional[str] = None,
encoding: Optional[str] = None,
emptyValue: Optional[str] = None,
lineSep: Optional[str] = None,
) -> None:
if mode not in (None, "overwrite"):
raise NotImplementedError
if escapeQuotes:
raise NotImplementedError
if ignoreLeadingWhiteSpace:
raise NotImplementedError
if ignoreTrailingWhiteSpace:
raise NotImplementedError
if charToEscapeQuoteEscaping:
raise NotImplementedError
if emptyValue:
raise NotImplementedError
if lineSep:
raise NotImplementedError
relation = self.dataframe.relation
relation.write_csv(
path,
sep=sep,
na_rep=nullValue,
quotechar=quote,
compression=compression,
escapechar=escape,
header=header if isinstance(header, bool) else header == "True",
encoding=encoding,
quoting=quoteAll,
date_format=dateFormat,
timestamp_format=timestampFormat,
)
class DataFrameReader: # noqa: D101
def __init__(self, session: "SparkSession") -> None: # noqa: D107
self.session = session
def load( # noqa: D102
self,
path: Optional[Union[str, list[str]]] = None,
format: Optional[str] = None,
schema: Optional[Union[StructType, str]] = None,
**options: OptionalPrimitiveType,
) -> "DataFrame":
from duckdb.experimental.spark.sql.dataframe import DataFrame
if not isinstance(path, str):
raise TypeError
if options:
raise ContributionsAcceptedError
rel = None
if format:
format = format.lower()
if format == "csv" or format == "tsv":
rel = self.session.conn.read_csv(path)
elif format == "json":
rel = self.session.conn.read_json(path)
elif format == "parquet":
rel = self.session.conn.read_parquet(path)
else:
raise ContributionsAcceptedError
else:
rel = self.session.conn.sql(f"select * from {path}")
df = DataFrame(rel, self.session)
if schema:
if not isinstance(schema, StructType):
raise ContributionsAcceptedError
schema = cast("StructType", schema)
types, names = schema.extract_types_and_names()
df = df._cast_types(types)
df = df.toDF(names)
return df
def csv( # noqa: D102
self,
path: Union[str, list[str]],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
quote: Optional[str] = None,
escape: Optional[str] = None,
comment: Optional[str] = None,
header: Optional[Union[bool, str]] = None,
inferSchema: Optional[Union[bool, str]] = None,
ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
nullValue: Optional[str] = None,
nanValue: Optional[str] = None,
positiveInf: Optional[str] = None,
negativeInf: Optional[str] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
maxColumns: Optional[Union[int, str]] = None,
maxCharsPerColumn: Optional[Union[int, str]] = None,
maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
mode: Optional[str] = None,
columnNameOfCorruptRecord: Optional[str] = None,
multiLine: Optional[Union[bool, str]] = None,
charToEscapeQuoteEscaping: Optional[str] = None,
samplingRatio: Optional[Union[float, str]] = None,
enforceSchema: Optional[Union[bool, str]] = None,
emptyValue: Optional[str] = None,
locale: Optional[str] = None,
lineSep: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
modifiedBefore: Optional[Union[bool, str]] = None,
modifiedAfter: Optional[Union[bool, str]] = None,
unescapedQuoteHandling: Optional[str] = None,
) -> "DataFrame":
if not isinstance(path, str):
raise NotImplementedError
if schema and not isinstance(schema, StructType):
raise ContributionsAcceptedError
if comment:
raise ContributionsAcceptedError
if inferSchema:
raise ContributionsAcceptedError
if ignoreLeadingWhiteSpace:
raise ContributionsAcceptedError
if ignoreTrailingWhiteSpace:
raise ContributionsAcceptedError
if nanValue:
raise ConnectionAbortedError
if positiveInf:
raise ConnectionAbortedError
if negativeInf:
raise ConnectionAbortedError
if negativeInf:
raise ConnectionAbortedError
if maxColumns:
raise ContributionsAcceptedError
if maxCharsPerColumn:
raise ContributionsAcceptedError
if maxMalformedLogPerPartition:
raise ContributionsAcceptedError
if mode:
raise ContributionsAcceptedError
if columnNameOfCorruptRecord:
raise ContributionsAcceptedError
if multiLine:
raise ContributionsAcceptedError
if charToEscapeQuoteEscaping:
raise ContributionsAcceptedError
if samplingRatio:
raise ContributionsAcceptedError
if enforceSchema:
raise ContributionsAcceptedError
if emptyValue:
raise ContributionsAcceptedError
if locale:
raise ContributionsAcceptedError
if pathGlobFilter:
raise ContributionsAcceptedError
if recursiveFileLookup:
raise ContributionsAcceptedError
if modifiedBefore:
raise ContributionsAcceptedError
if modifiedAfter:
raise ContributionsAcceptedError
if unescapedQuoteHandling:
raise ContributionsAcceptedError
if lineSep:
# We have support for custom newline, just needs to be ported to 'read_csv'
raise NotImplementedError
dtype = None
names = None
if schema:
schema = cast("StructType", schema)
dtype, names = schema.extract_types_and_names()
rel = self.session.conn.read_csv(
path,
header=header if isinstance(header, bool) else header == "True",
sep=sep,
dtype=dtype,
na_values=nullValue,
quotechar=quote,
escapechar=escape,
encoding=encoding,
date_format=dateFormat,
timestamp_format=timestampFormat,
)
from ..sql.dataframe import DataFrame
df = DataFrame(rel, self.session)
if names:
df = df.toDF(*names)
return df
def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": # noqa: D102
input = list(paths)
if len(input) != 1:
msg = "Only single paths are supported for now"
raise NotImplementedError(msg)
option_amount = len(options.keys())
if option_amount != 0:
msg = "Options are not supported"
raise ContributionsAcceptedError(msg)
path = input[0]
rel = self.session.conn.read_parquet(path)
from ..sql.dataframe import DataFrame
df = DataFrame(rel, self.session)
return df
def json(
self,
path: Union[str, list[str]],
schema: Optional[Union[StructType, str]] = None,
primitivesAsString: Optional[Union[bool, str]] = None,
prefersDecimal: Optional[Union[bool, str]] = None,
allowComments: Optional[Union[bool, str]] = None,
allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
allowSingleQuotes: Optional[Union[bool, str]] = None,
allowNumericLeadingZero: Optional[Union[bool, str]] = None,
allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
mode: Optional[str] = None,
columnNameOfCorruptRecord: Optional[str] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
multiLine: Optional[Union[bool, str]] = None,
allowUnquotedControlChars: Optional[Union[bool, str]] = None,
lineSep: Optional[str] = None,
samplingRatio: Optional[Union[float, str]] = None,
dropFieldIfAllNull: Optional[Union[bool, str]] = None,
encoding: Optional[str] = None,
locale: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
modifiedBefore: Optional[Union[bool, str]] = None,
modifiedAfter: Optional[Union[bool, str]] = None,
allowNonNumericNumbers: Optional[Union[bool, str]] = None,
) -> "DataFrame":
"""Loads JSON files and returns the results as a :class:`DataFrame`.
`JSON Lines <http://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
For JSON (one record per file), set the ``multiLine`` parameter to ``true``.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
.. versionadded:: 1.4.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
path : str, list or :class:`RDD`
string represents path to the JSON dataset, or a list of paths,
or RDD of Strings storing JSON objects.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema or
a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
Other Parameters
----------------
Extra options
For the extra options, refer to
`Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
for the version you use.
.. # noqa
Examples:
--------
Write a DataFrame into a JSON file and read it back.
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... # Write a DataFrame into a JSON file
... spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
... "overwrite"
... ).format("json").save(d)
...
... # Read the JSON file as a DataFrame.
... spark.read.json(d).show()
+---+------------+
|age| name|
+---+------------+
|100|Hyukjin Kwon|
+---+------------+
"""
if schema is not None:
msg = "The 'schema' option is not supported"
raise ContributionsAcceptedError(msg)
if primitivesAsString is not None:
msg = "The 'primitivesAsString' option is not supported"
raise ContributionsAcceptedError(msg)
if prefersDecimal is not None:
msg = "The 'prefersDecimal' option is not supported"
raise ContributionsAcceptedError(msg)
if allowComments is not None:
msg = "The 'allowComments' option is not supported"
raise ContributionsAcceptedError(msg)
if allowUnquotedFieldNames is not None:
msg = "The 'allowUnquotedFieldNames' option is not supported"
raise ContributionsAcceptedError(msg)
if allowSingleQuotes is not None:
msg = "The 'allowSingleQuotes' option is not supported"
raise ContributionsAcceptedError(msg)
if allowNumericLeadingZero is not None:
msg = "The 'allowNumericLeadingZero' option is not supported"
raise ContributionsAcceptedError(msg)
if allowBackslashEscapingAnyCharacter is not None:
msg = "The 'allowBackslashEscapingAnyCharacter' option is not supported"
raise ContributionsAcceptedError(msg)
if mode is not None:
msg = "The 'mode' option is not supported"
raise ContributionsAcceptedError(msg)
if columnNameOfCorruptRecord is not None:
msg = "The 'columnNameOfCorruptRecord' option is not supported"
raise ContributionsAcceptedError(msg)
if dateFormat is not None:
msg = "The 'dateFormat' option is not supported"
raise ContributionsAcceptedError(msg)
if timestampFormat is not None:
msg = "The 'timestampFormat' option is not supported"
raise ContributionsAcceptedError(msg)
if multiLine is not None:
msg = "The 'multiLine' option is not supported"
raise ContributionsAcceptedError(msg)
if allowUnquotedControlChars is not None:
msg = "The 'allowUnquotedControlChars' option is not supported"
raise ContributionsAcceptedError(msg)
if lineSep is not None:
msg = "The 'lineSep' option is not supported"
raise ContributionsAcceptedError(msg)
if samplingRatio is not None:
msg = "The 'samplingRatio' option is not supported"
raise ContributionsAcceptedError(msg)
if dropFieldIfAllNull is not None:
msg = "The 'dropFieldIfAllNull' option is not supported"
raise ContributionsAcceptedError(msg)
if encoding is not None:
msg = "The 'encoding' option is not supported"
raise ContributionsAcceptedError(msg)
if locale is not None:
msg = "The 'locale' option is not supported"
raise ContributionsAcceptedError(msg)
if pathGlobFilter is not None:
msg = "The 'pathGlobFilter' option is not supported"
raise ContributionsAcceptedError(msg)
if recursiveFileLookup is not None:
msg = "The 'recursiveFileLookup' option is not supported"
raise ContributionsAcceptedError(msg)
if modifiedBefore is not None:
msg = "The 'modifiedBefore' option is not supported"
raise ContributionsAcceptedError(msg)
if modifiedAfter is not None:
msg = "The 'modifiedAfter' option is not supported"
raise ContributionsAcceptedError(msg)
if allowNonNumericNumbers is not None:
msg = "The 'allowNonNumericNumbers' option is not supported"
raise ContributionsAcceptedError(msg)
if isinstance(path, str):
path = [path]
if isinstance(path, list):
if len(path) == 1:
rel = self.session.conn.read_json(path[0])
from .dataframe import DataFrame
df = DataFrame(rel, self.session)
return df
raise PySparkNotImplementedError(message="Only a single path is supported for now")
else:
raise PySparkTypeError(
error_class="NOT_STR_OR_LIST_OF_RDD",
message_parameters={
"arg_name": "path",
"arg_type": type(path).__name__,
},
)
__all__ = ["DataFrameReader", "DataFrameWriter"]

View File

@@ -0,0 +1,297 @@
import uuid # noqa: D100
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union
import duckdb
if TYPE_CHECKING:
from pandas.core.frame import DataFrame as PandasDataFrame
from .catalog import Catalog
from ..conf import SparkConf
from ..context import SparkContext
from ..errors import PySparkTypeError
from ..exception import ContributionsAcceptedError
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameReader
from .streaming import DataStreamReader
from .types import StructType
from .udf import UDFRegistration
# In spark:
# SparkSession holds a SparkContext
# SparkContext gets created from SparkConf
# At this level the check is made to determine whether the instance already exists and just needs
# to be retrieved or it needs to be created.
# For us this is done inside of `duckdb.connect`, based on the passed in path + configuration
# SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class
# data is a List of rows
# every value in each row needs to be turned into a Value
def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[duckdb.Value]:
from duckdb import Value
new_data = []
for row in data:
new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema])]
new_data.append(new_row)
return new_data
class SparkSession: # noqa: D101
def __init__(self, context: SparkContext) -> None: # noqa: D107
self.conn = context.connection
self._context = context
self._conf = RuntimeConfig(self.conn)
def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame:
try:
import pandas
has_pandas = True
except ImportError:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
unique_name = f"pyspark_pandas_df_{uuid.uuid1()}"
self.conn.register(unique_name, data)
return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)
def verify_tuple_integrity(tuples: list[tuple]) -> None:
if len(tuples) <= 1:
return
expected_length = len(tuples[0])
for i, item in enumerate(tuples[1:]):
actual_length = len(item)
if expected_length == actual_length:
continue
raise PySparkTypeError(
error_class="LENGTH_SHOULD_BE_THE_SAME",
message_parameters={
"arg1": f"data{i}",
"arg2": f"data{i + 1}",
"arg1_length": str(expected_length),
"arg2_length": str(actual_length),
},
)
if not isinstance(data, list):
data = list(data)
verify_tuple_integrity(data)
def construct_query(tuples: Iterable) -> str:
def construct_values_list(row: Sized, start_param_idx: int) -> str:
parameter_count = len(row)
parameters = [f"${x + start_param_idx}" for x in range(parameter_count)]
parameters = "(" + ", ".join(parameters) + ")"
return parameters
row_size = len(tuples[0])
values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)]
values_list = ", ".join(values_list)
query = f"""
select * from (values {values_list})
"""
return query
query = construct_query(data)
def construct_parameters(tuples: Iterable) -> list[list]:
parameters = []
for row in tuples:
parameters.extend(list(row))
return parameters
parameters = construct_parameters(data)
rel = self.conn.sql(query, params=parameters)
return DataFrame(rel, self)
def _createDataFrameFromPandas(
self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None]
) -> DataFrame:
df = self._create_dataframe(data)
# Cast to types
if types:
df = df._cast_types(*types)
# Alias to names
if names:
df = df.toDF(*names)
return df
def createDataFrame( # noqa: D102
self,
data: Union["PandasDataFrame", Iterable[Any]],
schema: Optional[Union[StructType, list[str]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
) -> DataFrame:
if samplingRatio:
raise NotImplementedError
if not verifySchema:
raise NotImplementedError
types = None
names = None
if isinstance(data, DataFrame):
raise PySparkTypeError(
error_class="SHOULD_NOT_DATAFRAME",
message_parameters={"arg_name": "data"},
)
if schema:
if isinstance(schema, StructType):
types, names = schema.extract_types_and_names()
else:
names = schema
try:
import pandas
has_pandas = True
except ImportError:
has_pandas = False
# Falsey check on pandas dataframe is not defined, so first check if it's not a pandas dataframe
# Then check if 'data' is None or []
if has_pandas and isinstance(data, pandas.DataFrame):
return self._createDataFrameFromPandas(data, types, names)
# Finally check if a schema was provided
is_empty = False
if not data and names:
# Create NULLs for every type in our dataframe
is_empty = True
data = [tuple(None for _ in names)]
if schema and isinstance(schema, StructType):
# Transform the data into Values to combine the data+schema
data = _combine_data_and_schema(data, schema)
df = self._create_dataframe(data)
if is_empty:
rel = df.relation
# Add impossible where clause
rel = rel.filter("1=0")
df = DataFrame(rel, self)
# Cast to types
if types:
df = df._cast_types(*types)
# Alias to names
if names:
df = df.toDF(*names)
return df
def newSession(self) -> "SparkSession": # noqa: D102
return SparkSession(self._context)
def range( # noqa: D102
self,
start: int,
end: Optional[int] = None,
step: int = 1,
numPartitions: Optional[int] = None,
) -> "DataFrame":
if numPartitions:
raise ContributionsAcceptedError
if end is None:
end = start
start = 0
return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self)
def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102, ANN401
if kwargs:
raise NotImplementedError
relation = self.conn.sql(sqlQuery)
return DataFrame(relation, self)
def stop(self) -> None: # noqa: D102
self._context.stop()
def table(self, tableName: str) -> DataFrame: # noqa: D102
relation = self.conn.table(tableName)
return DataFrame(relation, self)
def getActiveSession(self) -> "SparkSession": # noqa: D102
return self
@property
def catalog(self) -> "Catalog": # noqa: D102
if not hasattr(self, "_catalog"):
from duckdb.experimental.spark.sql.catalog import Catalog
self._catalog = Catalog(self)
return self._catalog
@property
def conf(self) -> RuntimeConfig: # noqa: D102
return self._conf
@property
def read(self) -> DataFrameReader: # noqa: D102
return DataFrameReader(self)
@property
def readStream(self) -> DataStreamReader: # noqa: D102
return DataStreamReader(self)
@property
def sparkContext(self) -> SparkContext: # noqa: D102
return self._context
@property
def streams(self) -> NoReturn: # noqa: D102
raise ContributionsAcceptedError
@property
def udf(self) -> UDFRegistration: # noqa: D102
return UDFRegistration(self)
@property
def version(self) -> str: # noqa: D102
return "1.0.0"
class Builder: # noqa: D106
def __init__(self) -> None: # noqa: D107
pass
def master(self, name: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def appName(self, name: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def remote(self, url: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def getOrCreate(self) -> "SparkSession": # noqa: D102
context = SparkContext("__ignored__")
return SparkSession(context)
def config( # noqa: D102
self,
key: Optional[str] = None,
value: Optional[Any] = None, # noqa: ANN401
conf: Optional[SparkConf] = None,
) -> "SparkSession.Builder":
return self
def enableHiveSupport(self) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
builder = Builder()
__all__ = ["SparkSession"]

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Optional, Union # noqa: D100
from .types import StructType
if TYPE_CHECKING:
from .dataframe import DataFrame
from .session import SparkSession
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
class DataStreamWriter: # noqa: D101
def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107
self.dataframe = dataframe
def toTable(self, table_name: str) -> None: # noqa: D102
# Should we register the dataframe or create a table from the contents?
raise NotImplementedError
class DataStreamReader: # noqa: D101
def __init__(self, session: "SparkSession") -> None: # noqa: D107
self.session = session
def load( # noqa: D102
self,
path: Optional[str] = None,
format: Optional[str] = None,
schema: Union[StructType, str, None] = None,
**options: OptionalPrimitiveType,
) -> "DataFrame":
raise NotImplementedError
__all__ = ["DataStreamReader", "DataStreamWriter"]

View File

@@ -0,0 +1,113 @@
from typing import cast # noqa: D100
from duckdb.sqltypes import DuckDBPyType
from ..exception import ContributionsAcceptedError
from .types import (
ArrayType,
BinaryType,
BitstringType,
BooleanType,
ByteType,
DataType,
DateType,
DayTimeIntervalType,
DecimalType,
DoubleType,
FloatType,
HugeIntegerType,
IntegerType,
LongType,
MapType,
ShortType,
StringType,
StructField,
StructType,
TimeNTZType,
TimestampMilisecondNTZType,
TimestampNanosecondNTZType,
TimestampNTZType,
TimestampSecondNTZType,
TimestampType,
TimeType,
UnsignedByteType,
UnsignedHugeIntegerType,
UnsignedIntegerType,
UnsignedLongType,
UnsignedShortType,
UUIDType,
)
_sqltype_to_spark_class = {
"boolean": BooleanType,
"utinyint": UnsignedByteType,
"tinyint": ByteType,
"usmallint": UnsignedShortType,
"smallint": ShortType,
"uinteger": UnsignedIntegerType,
"integer": IntegerType,
"ubigint": UnsignedLongType,
"bigint": LongType,
"hugeint": HugeIntegerType,
"uhugeint": UnsignedHugeIntegerType,
"varchar": StringType,
"blob": BinaryType,
"bit": BitstringType,
"uuid": UUIDType,
"date": DateType,
"time": TimeNTZType,
"time with time zone": TimeType,
"timestamp": TimestampNTZType,
"timestamp with time zone": TimestampType,
"timestamp_ms": TimestampNanosecondNTZType,
"timestamp_ns": TimestampMilisecondNTZType,
"timestamp_s": TimestampSecondNTZType,
"interval": DayTimeIntervalType,
"list": ArrayType,
"struct": StructType,
"map": MapType,
# union
# enum
# null (???)
"float": FloatType,
"double": DoubleType,
"decimal": DecimalType,
}
def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
id = dtype.id
if id == "list" or id == "array":
children = dtype.children
return ArrayType(convert_type(children[0][1]))
if id == "union":
msg = (
"Union types are not supported in the PySpark interface. "
"DuckDB union types cannot be directly mapped to PySpark types."
)
raise ContributionsAcceptedError(msg)
if id == "struct":
children: list[tuple[str, DuckDBPyType]] = dtype.children
fields = [StructField(x[0], convert_type(x[1])) for x in children]
return StructType(fields)
if id == "map":
return MapType(convert_type(dtype.key), convert_type(dtype.value))
raise NotImplementedError
def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
id = dtype.id
if id in ["list", "struct", "map", "array"]:
return convert_nested_type(dtype)
if id == "decimal":
children: list[tuple[str, DuckDBPyType]] = dtype.children
precision = cast("int", children[0][1])
scale = cast("int", children[1][1])
return DecimalType(precision, scale)
spark_type = _sqltype_to_spark_class[id]
return spark_type()
def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103
fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])]
return StructType(fields)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from .types import DataType
if TYPE_CHECKING:
from .session import SparkSession
DataTypeOrString = Union[DataType, str]
UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike")
class UDFRegistration: # noqa: D101
def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107
self.sparkSession = sparkSession
def register( # noqa: D102
self,
name: str,
f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
returnType: Optional["DataTypeOrString"] = None,
) -> "UserDefinedFunctionLike":
self.sparkSession.conn.create_function(name, f, return_type=returnType)
def registerJavaFunction( # noqa: D102
self,
name: str,
javaClassName: str,
returnType: Optional["DataTypeOrString"] = None,
) -> None:
raise NotImplementedError
def registerJavaUDAF(self, name: str, javaClassName: str) -> None: # noqa: D102
raise NotImplementedError
__all__ = ["UDFRegistration"]