initial commit
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from .catalog import Catalog # noqa: D104
|
||||
from .conf import RuntimeConfig
|
||||
from .dataframe import DataFrame
|
||||
from .readwriter import DataFrameWriter
|
||||
from .session import SparkSession
|
||||
|
||||
__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import Literal, Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
|
||||
from .._typing import PrimitiveType
|
||||
from . import types
|
||||
from .column import Column
|
||||
|
||||
# A column argument: either a Column object or a plain string.
ColumnOrName = Union[Column, str]
# Type variable whose upper bound is ColumnOrName.
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
# Alias for decimal.Decimal.
DecimalLiteral = decimal.Decimal
# A datetime.datetime or a datetime.date value.
DateTimeLiteral = Union[datetime.datetime, datetime.date]
# Alias for PrimitiveType, imported from the parent package's _typing module.
LiteralType = PrimitiveType
# An AtomicType instance or a string (presumably a type name; confirm against callers).
AtomicDataTypeOrString = Union[types.AtomicType, str]
# A DataType instance or a string (presumably a type name; confirm against callers).
DataTypeOrString = Union[types.DataType, str]
# PrimitiveType or None.
OptionalPrimitiveType = Optional[PrimitiveType]

# Type variable constrained to exactly the scalar Python types listed below.
AtomicValue = TypeVar(
    "AtomicValue",
    datetime.datetime,
    datetime.date,
    decimal.Decimal,
    bool,
    str,
    int,
    float,
)

# Type variable constrained to a list, a tuple, or a types.Row.
RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row)

# Literal type that admits only the integer 100.
# NOTE(review): presumably mirrors PySpark's SQL_BATCHED_UDF eval-type code; confirm.
SQLBatchedUDFType = Literal[100]
|
||||
|
||||
|
||||
class SupportsOpen(Protocol):
    """Structural type for objects exposing ``open(partition_id, epoch_id) -> bool``."""

    # Protocol member: declaration only, no behaviour is defined here.
    def open(self, partition_id: int, epoch_id: int) -> bool: ...
|
||||
|
||||
|
||||
class SupportsProcess(Protocol):
    """Structural type for objects exposing ``process(row) -> None`` taking a ``types.Row``."""

    # Protocol member: declaration only, no behaviour is defined here.
    def process(self, row: types.Row) -> None: ...
|
||||
|
||||
|
||||
class SupportsClose(Protocol):
    """Structural type for objects exposing ``close(error) -> None`` taking an ``Exception``."""

    # Protocol member: declaration only, no behaviour is defined here.
    def close(self, error: Exception) -> None: ...
|
||||
|
||||
|
||||
class UserDefinedFunctionLike(Protocol):
    """Structural type describing the surface of a user-defined function object.

    A conforming object carries the attributes ``func``, ``evalType`` and
    ``deterministic``, exposes a ``returnType`` property, can be called with
    column arguments to produce a :class:`Column`, and offers
    ``asNondeterministic()`` returning another object of this shape.
    """

    # Arbitrary callable held by the UDF object.
    func: Callable[..., Any]
    # Integer attribute; NOTE(review): presumably a PySpark eval-type code, confirm.
    evalType: int
    # Boolean attribute; its effect is not defined in this module.
    deterministic: bool

    @property
    def returnType(self) -> types.DataType: ...

    def __call__(self, *args: ColumnOrName) -> Column: ...

    def asNondeterministic(self) -> "UserDefinedFunctionLike": ...
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import NamedTuple, Optional, Union # noqa: D100
|
||||
|
||||
from .session import SparkSession
|
||||
|
||||
|
||||
class Database(NamedTuple):
    """Immutable record describing one database, as returned by ``Catalog.listDatabases``."""

    # Database name.
    name: str
    # Optional free-text description; may be None.
    description: Optional[str]
    # Location URI of the database, as a string.
    locationUri: str
|
||||
|
||||
|
||||
class Table(NamedTuple):
    """Immutable record describing one table, as returned by ``Catalog.listTables``."""

    # Table name.
    name: str
    # Name of the owning database; may be None.
    database: Optional[str]
    # Optional description; may be None.
    description: Optional[str]
    # Table type, as a string.
    tableType: str
    # Whether the table is temporary.
    isTemporary: bool
|
||||
|
||||
|
||||
class Column(NamedTuple):
    """Immutable record describing one table column, as returned by ``Catalog.listColumns``."""

    # Column name.
    name: str
    # Optional description; may be None.
    description: Optional[str]
    # Data type of the column, as a string.
    dataType: str
    # Whether the column accepts NULL values.
    nullable: bool
    # Whether the column is a partition column.
    isPartition: bool
    # Whether the column is a bucket column.
    isBucket: bool
|
||||
|
||||
|
||||
class Function(NamedTuple):
    """Immutable record describing one function (the element type of ``Catalog.listFunctions``)."""

    # Function name.
    name: str
    # Optional description; may be None.
    description: Optional[str]
    # Implementing class name, as a string.
    className: str
    # Whether the function is temporary.
    isTemporary: bool
|
||||
|
||||
|
||||
class Catalog:
    """Catalog facade that answers Spark-style metadata calls by querying DuckDB.

    Metadata is read through ``session.conn`` using the ``duckdb_databases()``,
    ``duckdb_tables()`` and ``duckdb_columns()`` table functions.
    """

    def __init__(self, session: SparkSession) -> None:
        """Keep a reference to the session whose connection is queried."""
        self._session = session

    def listDatabases(self) -> list[Database]:
        """Return one :class:`Database` per row of ``duckdb_databases()``.

        Only ``name`` is populated; ``description`` is ``None`` and
        ``locationUri`` is the empty string.
        """
        rows = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
        return [Database(name=row[0], description=None, locationUri="") for row in rows]

    def listTables(self) -> list[Table]:
        """Return one :class:`Table` per row of ``duckdb_tables()``.

        The ``sql`` column of the metadata row is stored as the description and
        ``tableType`` is always the empty string.
        """
        rows = self._session.conn.sql(
            "select table_name, database_name, sql, temporary from duckdb_tables()"
        ).fetchall()
        return [
            Table(name=row[0], database=row[1], description=row[2], tableType="", isTemporary=row[3])
            for row in rows
        ]

    def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]:
        """Return the columns of ``tableName`` as :class:`Column` records.

        Parameters
        ----------
        tableName : str
            Table whose columns are listed.
        dbName : str, optional
            When given (and non-empty), only columns whose ``database_name``
            matches are returned.

        Returns:
        -------
        list[Column]
            One record per matching row of ``duckdb_columns()``. ``description``
            is ``None``; ``isPartition`` and ``isBucket`` are always ``False``.
        """
        # The names are bound as prepared-statement parameters rather than spliced
        # into the SQL text, so quotes inside a name cannot change the query.
        query = "select column_name, data_type, is_nullable from duckdb_columns() where table_name = ?"
        params: list[str] = [tableName]
        if dbName:
            query += " and database_name = ?"
            params.append(dbName)
        rows = self._session.conn.sql(query, params=params).fetchall()
        return [
            Column(
                name=row[0],
                description=None,
                dataType=row[1],
                nullable=row[2],
                isPartition=False,
                isBucket=False,
            )
            for row in rows
        ]

    def listFunctions(self, dbName: Optional[str] = None) -> list[Function]:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def setCurrentDatabase(self, dbName: str) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
||||
|
||||
|
||||
__all__ = ["Catalog", "Column", "Database", "Function", "Table"]
|
||||
361
venv/Lib/site-packages/duckdb/experimental/spark/sql/column.py
Normal file
361
venv/Lib/site-packages/duckdb/experimental/spark/sql/column.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from collections.abc import Iterable # noqa: D100
|
||||
from typing import TYPE_CHECKING, Any, Callable, Union, cast
|
||||
|
||||
from ..exception import ContributionsAcceptedError
|
||||
from .types import DataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType
|
||||
|
||||
from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression
|
||||
from duckdb.sqltypes import DuckDBPyType
|
||||
|
||||
__all__ = ["Column"]
|
||||
|
||||
|
||||
def _get_expr(x: Union["Column", str]) -> Expression:
    """Return the expression behind ``x``.

    A :class:`Column` contributes its wrapped ``expr``; any other value is
    turned into a ``ConstantExpression``.
    """
    if isinstance(x, Column):
        return x.expr
    return ConstantExpression(x)
|
||||
|
||||
|
||||
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
|
||||
def _(self: "Column") -> "Column":
|
||||
njc = getattr(self.expr, name)()
|
||||
return Column(njc)
|
||||
|
||||
_.__doc__ = doc
|
||||
return _
|
||||
|
||||
|
||||
def _unary_op(
|
||||
name: str,
|
||||
doc: str = "unary operator",
|
||||
) -> Callable[["Column"], "Column"]:
|
||||
"""Create a method for given unary operator."""
|
||||
|
||||
def _(self: "Column") -> "Column":
|
||||
# Call the function identified by 'name' on the internal Expression object
|
||||
expr = getattr(self.expr, name)()
|
||||
return Column(expr)
|
||||
|
||||
_.__doc__ = doc
|
||||
return _
|
||||
|
||||
|
||||
def _bin_op(
|
||||
name: str,
|
||||
doc: str = "binary operator",
|
||||
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
|
||||
"""Create a method for given binary operator."""
|
||||
|
||||
def _(
|
||||
self: "Column",
|
||||
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
|
||||
) -> "Column":
|
||||
jc = _get_expr(other)
|
||||
njc = getattr(self.expr, name)(jc)
|
||||
return Column(njc)
|
||||
|
||||
_.__doc__ = doc
|
||||
return _
|
||||
|
||||
|
||||
def _bin_func(
|
||||
name: str,
|
||||
doc: str = "binary function",
|
||||
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
|
||||
"""Create a function expression for the given binary function."""
|
||||
|
||||
def _(
|
||||
self: "Column",
|
||||
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
|
||||
) -> "Column":
|
||||
other = _get_expr(other)
|
||||
func = FunctionExpression(name, self.expr, other)
|
||||
return Column(func)
|
||||
|
||||
_.__doc__ = doc
|
||||
return _
|
||||
|
||||
|
||||
class Column:
    """A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame

        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName

    Each instance wraps a DuckDB ``Expression`` in :attr:`expr`. Operators and
    methods build a new expression from it and return that wrapped in a new
    :class:`Column`; the receiver is not modified.

    .. versionadded:: 1.3.0
    """

    def __init__(self, expr: Expression) -> None:
        """Store ``expr`` as the wrapped expression."""
        self.expr = expr

    # arithmetic operators
    def __neg__(self) -> "Column":
        """Return a column wrapping the negated expression (``-col``)."""
        return Column(-self.expr)

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op("__and__")
    __or__ = _bin_op("__or__")
    __invert__ = _func_op("__invert__")
    __rand__ = _bin_op("__rand__")
    __ror__ = _bin_op("__ror__")

    __add__ = _bin_op("__add__")

    __sub__ = _bin_op("__sub__")

    __mul__ = _bin_op("__mul__")

    __div__ = _bin_op("__div__")

    __truediv__ = _bin_op("__truediv__")

    __mod__ = _bin_op("__mod__")

    __pow__ = _bin_op("__pow__")

    __radd__ = _bin_op("__radd__")

    __rsub__ = _bin_op("__rsub__")

    __rmul__ = _bin_op("__rmul__")

    __rdiv__ = _bin_op("__rdiv__")

    __rtruediv__ = _bin_op("__rtruediv__")

    __rmod__ = _bin_op("__rmod__")

    __rpow__ = _bin_op("__rpow__")

    def __getitem__(self, k: Any) -> "Column":  # noqa: ANN401
        """Return a column that addresses member ``k`` of this column.

        The result is a ``ColumnExpression`` named ``"<this expression>.<k>"``,
        i.e. the key is appended to the textual form of the current expression.

        .. versionadded:: 1.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        k
            a literal value used as the key. Slice objects are not supported
            by this implementation.

        Returns:
        -------
        :class:`Column`
            Column referring to ``<this expression>.<k>``.

        Raises:
        ------
        ContributionsAcceptedError
            If ``k`` is a :class:`slice` (PySpark maps this to ``substr``; that
            mapping is not implemented here).

        Examples:
        --------
        PySpark reference output; column naming may differ in this implementation.

        >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
        >>> df.select(df.d["key"]).show()
        +------+
        |d[key]|
        +------+
        | value|
        +------+
        """
        if isinstance(k, slice):
            # PySpark would reject a step and return self.substr(k.start, k.stop).
            raise ContributionsAcceptedError
        # TODO: this is super hacky  # noqa: TD002, TD003
        expr_str = str(self.expr) + "." + str(k)
        return Column(ColumnExpression(expr_str))

    def __getattr__(self, item: Any) -> "Column":  # noqa: ANN401
        """Return ``self[item]`` for attribute-style member access.

        Python only calls this hook when normal attribute lookup fails, so
        regular methods and ``expr`` are unaffected.

        Parameters
        ----------
        item
            name of the member to access.

        Returns:
        -------
        :class:`Column`
            Same column that ``self[item]`` produces.

        Raises:
        ------
        AttributeError
            If ``item`` starts with a double underscore, so that probes for
            special methods fail normally instead of yielding a Column.

        Examples:
        --------
        PySpark reference output; column naming may differ in this implementation.

        >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
        >>> df.select(df.d.key).show()
        +------+
        |d[key]|
        +------+
        | value|
        +------+
        """
        if item.startswith("__"):
            msg = "Can not access __ (dunder) method"
            raise AttributeError(msg)
        return self[item]

    def alias(self, alias: str) -> "Column":
        """Return this column with its expression aliased to ``alias``."""
        return Column(self.expr.alias(alias))

    def when(self, condition: "Column", value: Union["Column", str]) -> "Column":
        """Add a ``when`` branch to the wrapped expression.

        Parameters
        ----------
        condition : :class:`Column`
            branch condition.
        value : :class:`Column` or str
            branch result; converted with ``_get_expr``.

        Raises:
        ------
        TypeError
            If ``condition`` is not a :class:`Column`.
        """
        if not isinstance(condition, Column):
            msg = "condition should be a Column"
            raise TypeError(msg)
        v = _get_expr(value)
        expr = self.expr.when(condition.expr, v)
        return Column(expr)

    def otherwise(self, value: Union["Column", str]) -> "Column":
        """Add the ``otherwise`` (fallback) result to the wrapped expression.

        ``value`` is converted with ``_get_expr`` before being applied.
        """
        v = _get_expr(value)
        expr = self.expr.otherwise(v)
        return Column(expr)

    def cast(self, dataType: Union[DataType, str]) -> "Column":
        """Cast the column to ``dataType``.

        A string is converted with ``DuckDBPyType``; any other value must
        provide a ``duckdb_type`` attribute, which is used as the target type.
        """
        internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type
        return Column(self.expr.cast(internal_type))

    def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column":
        """Return a membership test of this column against the given values.

        Values may be passed as separate arguments or as a single ``list`` or
        ``set``; each one is converted with ``_get_expr``.
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            # Only one argument supplied, it's a list
            cols = cast("tuple", cols[0])

        cols = cast(
            "tuple",
            [_get_expr(c) for c in cols],
        )
        return Column(self.expr.isin(*cols))

    # logistic operators
    def __eq__(  # type: ignore[override]
        self,
        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
    ) -> "Column":
        """Return an equality comparison column (not a ``bool``)."""
        return Column(self.expr == (_get_expr(other)))

    def __ne__(  # type: ignore[override]
        self,
        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
    ) -> "Column":
        """Return an inequality comparison column (not a ``bool``)."""
        return Column(self.expr != (_get_expr(other)))

    __lt__ = _bin_op("__lt__")

    __le__ = _bin_op("__le__")

    __ge__ = _bin_op("__ge__")

    __gt__ = _bin_op("__gt__")

    # String interrogation methods
    # Each one builds a FunctionExpression with the DuckDB function/operator name given here.

    contains = _bin_func("contains")
    rlike = _bin_func("regexp_matches")
    like = _bin_func("~~")
    ilike = _bin_func("~~*")
    startswith = _bin_func("starts_with")
    endswith = _bin_func("suffix")

    # order
    _asc_doc = """
    Returns a sort expression based on the ascending order of the column.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc()).collect()
    [Row(name='Alice'), Row(name='Tom')]
    """

    _asc_nulls_first_doc = """
    Returns a sort expression based on ascending order of the column, and null values
    appear before non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
    [Row(name=None), Row(name='Alice'), Row(name='Tom')]

    """
    _asc_nulls_last_doc = """
    Returns a sort expression based on ascending order of the column, and null values
    appear after non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
    [Row(name='Alice'), Row(name='Tom'), Row(name=None)]

    """
    _desc_doc = """
    Returns a sort expression based on the descending order of the column.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc()).collect()
    [Row(name='Tom'), Row(name='Alice')]
    """
    _desc_nulls_first_doc = """
    Returns a sort expression based on the descending order of the column, and null values
    appear before non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
    [Row(name=None), Row(name='Tom'), Row(name='Alice')]

    """
    _desc_nulls_last_doc = """
    Returns a sort expression based on the descending order of the column, and null values
    appear after non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
    [Row(name='Tom'), Row(name='Alice'), Row(name=None)]
    """

    asc = _unary_op("asc", _asc_doc)
    desc = _unary_op("desc", _desc_doc)
    nulls_first = _unary_op("nulls_first")
    nulls_last = _unary_op("nulls_last")

    def asc_nulls_first(self) -> "Column":  # noqa: D102
        return self.asc().nulls_first()

    def asc_nulls_last(self) -> "Column":  # noqa: D102
        return self.asc().nulls_last()

    def desc_nulls_first(self) -> "Column":  # noqa: D102
        return self.desc().nulls_first()

    def desc_nulls_last(self) -> "Column":  # noqa: D102
        return self.desc().nulls_last()

    # Attach the prepared documentation strings above; they were previously
    # defined but never used, which left these four methods undocumented.
    asc_nulls_first.__doc__ = _asc_nulls_first_doc
    asc_nulls_last.__doc__ = _asc_nulls_last_doc
    desc_nulls_first.__doc__ = _desc_nulls_first_doc
    desc_nulls_last.__doc__ = _desc_nulls_last_doc

    def isNull(self) -> "Column":
        """Return a column built from the wrapped expression's ``isnull()``."""
        return Column(self.expr.isnull())

    def isNotNull(self) -> "Column":
        """Return a column built from the wrapped expression's ``isnotnull()``."""
        return Column(self.expr.isnotnull())
|
||||
24
venv/Lib/site-packages/duckdb/experimental/spark/sql/conf.py
Normal file
24
venv/Lib/site-packages/duckdb/experimental/spark/sql/conf.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Optional, Union # noqa: D100
|
||||
|
||||
from duckdb import DuckDBPyConnection
|
||||
from duckdb.experimental.spark._globals import _NoValue, _NoValueType
|
||||
|
||||
|
||||
class RuntimeConfig:
    """Placeholder for the Spark runtime-configuration interface.

    The object keeps the DuckDB connection it is given, but every accessor
    below currently raises ``NotImplementedError``.
    """

    def __init__(self, connection: DuckDBPyConnection) -> None:
        """Store ``connection`` for later use."""
        self._connection = connection

    def set(self, key: str, value: str) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def isModifiable(self, key: str) -> bool:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def unset(self, key: str) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str:
        """Not implemented; always raises ``NotImplementedError``.

        ``default`` uses the ``_NoValue`` sentinel so that an explicit ``None``
        can be told apart from an omitted argument.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
__all__ = ["RuntimeConfig"]
|
||||
1423
venv/Lib/site-packages/duckdb/experimental/spark/sql/dataframe.py
Normal file
1423
venv/Lib/site-packages/duckdb/experimental/spark/sql/dataframe.py
Normal file
File diff suppressed because it is too large
Load Diff
6216
venv/Lib/site-packages/duckdb/experimental/spark/sql/functions.py
Normal file
6216
venv/Lib/site-packages/duckdb/experimental/spark/sql/functions.py
Normal file
File diff suppressed because it is too large
Load Diff
424
venv/Lib/site-packages/duckdb/experimental/spark/sql/group.py
Normal file
424
venv/Lib/site-packages/duckdb/experimental/spark/sql/group.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# # noqa: D100
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Union, overload
|
||||
|
||||
from ..exception import ContributionsAcceptedError
|
||||
from .column import Column
|
||||
from .dataframe import DataFrame
|
||||
from .functions import _to_column_expr
|
||||
from .types import NumericType
|
||||
|
||||
# Only import symbols needed for type checking if something is type checking
|
||||
if TYPE_CHECKING:
|
||||
from ._typing import ColumnOrName
|
||||
from .session import SparkSession
|
||||
|
||||
__all__ = ["GroupedData", "Grouping"]
|
||||
|
||||
|
||||
def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame:
    """Apply the aggregate function ``name`` to ``cols`` for the given grouped data.

    Parameters
    ----------
    self : GroupedData
        grouped data providing ``_grouping``, ``_df`` and ``session``.
    name : str
        aggregate function name handed to ``relation.apply``.
    cols : str
        aggregate inputs; they are joined with commas into one string.

    Returns:
    -------
    DataFrame
        New DataFrame wrapping the relation returned by ``relation.apply``,
        bound to the same session.
    """
    grouping = self._grouping
    aggregate_inputs = ",".join(cols)
    group_expr = str(grouping) if grouping else ""
    relation = self._df.relation.apply(
        function_name=name,  # aggregate function
        function_aggr=aggregate_inputs,  # inputs to aggregate
        group_expr=group_expr,  # groups
        projected_columns=grouping.get_columns(),  # projections
    )
    return DataFrame(relation, self.session)
|
||||
|
||||
|
||||
def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
    """Turn a documentation-only method stub into a working aggregate method.

    The decorated function's body is never executed. The replacement calls
    ``_api_internal`` with the stub's ``__name__`` as the aggregate function
    name and forwards the column arguments unchanged.

    Parameters
    ----------
    f : Callable[..., DataFrame]
        stub whose name selects the aggregate function and whose docstring
        documents the generated method.

    Returns:
    -------
    Callable[..., DataFrame]
        Method with signature ``(self, *cols)`` carrying the stub's metadata.
    """
    # Imported locally so this helper does not depend on the module's import block.
    from functools import wraps

    # wraps() copies __name__ and __doc__ as before, and additionally preserves
    # __qualname__, __module__ and __wrapped__ for introspection tools.
    @wraps(f)
    def _api(self: "GroupedData", *cols: str) -> DataFrame:
        return _api_internal(self, f.__name__, *cols)

    return _api
|
||||
|
||||
|
||||
class Grouping:
    """Grouping columns plus an optional grouping kind (``cube`` or ``rollup``)."""

    def __init__(self, *cols: "ColumnOrName", **kwargs) -> None:
        """Convert ``cols`` with ``_to_column_expr`` and record the grouping kind.

        The only keyword read is ``special``; when present it must be
        ``"cube"`` or ``"rollup"`` (checked with ``assert``). Without it the
        kind is the empty string.
        """
        self._type = ""
        self._cols = [_to_column_expr(col) for col in cols]
        if "special" not in kwargs:
            return
        kind = kwargs["special"]
        assert kind in ["cube", "rollup"]
        self._type = kind

    def get_columns(self) -> str:
        """Return the grouping columns as one comma-separated string."""
        return ",".join(str(col) for col in self._cols)

    def __str__(self) -> str:
        """Return the column list, wrapped as ``kind(columns)`` when a kind is set."""
        column_list = self.get_columns()
        if not self._type:
            return column_list
        return f"{self._type}({column_list})"
|
||||
|
||||
|
||||
class GroupedData:
    """Aggregation methods for a grouped :class:`DataFrame`.

    Instances pair a :class:`Grouping` with the DataFrame it applies to; they
    are created by :func:`DataFrame.groupBy`.
    """

    def __init__(self, grouping: Grouping, df: DataFrame) -> None:
        """Remember the grouping, the source DataFrame and that DataFrame's session."""
        self._df = df
        self._grouping = grouping
        self.session: SparkSession = df.session

    def __repr__(self) -> str:
        """Return the string form of the underlying DataFrame."""
        return str(self._df)

    def count(self) -> DataFrame:
        """Counts the number of records for each group.

        The aggregate result column ``count_star()`` is renamed to ``count``.

        Examples:
        --------
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
        ... )
        >>> df.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  2|Alice|
        |  3|Alice|
        |  5|  Bob|
        | 10|  Bob|
        +---+-----+

        Group-by name, and count each group.

        >>> df.groupBy(df.name).count().sort("name").show()
        +-----+-----+
        | name|count|
        +-----+-----+
        |Alice|    2|
        |  Bob|    2|
        +-----+-----+
        """
        counted = _api_internal(self, "count")
        return counted.withColumnRenamed("count_star()", "count")

    @df_varargs_api
    def mean(self, *cols: str) -> DataFrame:
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.
        """

    def avg(self, *cols: str) -> DataFrame:
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        Parameters
        ----------
        cols : str
            column names. When none are given, every field of the DataFrame
            schema whose data type is a ``NumericType`` is averaged.

        Examples:
        --------
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
        ...     ["age", "name", "height"],
        ... )
        >>> df.show()
        +---+-----+------+
        |age| name|height|
        +---+-----+------+
        |  2|Alice|    80|
        |  3|Alice|   100|
        |  5|  Bob|   120|
        | 10|  Bob|   140|
        +---+-----+------+

        Group-by name, and calculate the mean of the age in each group.

        >>> df.groupBy("name").avg("age").sort("name").show()
        +-----+--------+
        | name|avg(age)|
        +-----+--------+
        |Alice|     2.5|
        |  Bob|     7.5|
        +-----+--------+

        Calculate the mean of the age and height in all data.

        >>> df.groupBy().avg("age", "height").show()
        +--------+-----------+
        |avg(age)|avg(height)|
        +--------+-----------+
        |     5.0|      110.0|
        +--------+-----------+
        """
        if cols:
            return _api_internal(self, "avg", *cols)
        # No explicit columns: take only the numeric fields of the relation.
        numeric_names = [
            field.name for field in self._df.schema.fields if isinstance(field.dataType, NumericType)
        ]
        return _api_internal(self, "avg", *numeric_names)

    @df_varargs_api
    def max(self, *cols: str) -> DataFrame:
        """Computes the max value for each numeric columns for each group.

        Examples:
        --------
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
        ...     ["age", "name", "height"],
        ... )
        >>> df.show()
        +---+-----+------+
        |age| name|height|
        +---+-----+------+
        |  2|Alice|    80|
        |  3|Alice|   100|
        |  5|  Bob|   120|
        | 10|  Bob|   140|
        +---+-----+------+

        Group-by name, and calculate the max of the age in each group.

        >>> df.groupBy("name").max("age").sort("name").show()
        +-----+--------+
        | name|max(age)|
        +-----+--------+
        |Alice|       3|
        |  Bob|      10|
        +-----+--------+

        Calculate the max of the age and height in all data.

        >>> df.groupBy().max("age", "height").show()
        +--------+-----------+
        |max(age)|max(height)|
        +--------+-----------+
        |      10|        140|
        +--------+-----------+
        """

    @df_varargs_api
    def min(self, *cols: str) -> DataFrame:
        """Computes the min value for each numeric column for each group.

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples:
        --------
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
        ...     ["age", "name", "height"],
        ... )
        >>> df.show()
        +---+-----+------+
        |age| name|height|
        +---+-----+------+
        |  2|Alice|    80|
        |  3|Alice|   100|
        |  5|  Bob|   120|
        | 10|  Bob|   140|
        +---+-----+------+

        Group-by name, and calculate the min of the age in each group.

        >>> df.groupBy("name").min("age").sort("name").show()
        +-----+--------+
        | name|min(age)|
        +-----+--------+
        |Alice|       2|
        |  Bob|       5|
        +-----+--------+

        Calculate the min of the age and height in all data.

        >>> df.groupBy().min("age", "height").show()
        +--------+-----------+
        |min(age)|min(height)|
        +--------+-----------+
        |       2|         80|
        +--------+-----------+
        """

    @df_varargs_api
    def sum(self, *cols: str) -> DataFrame:
        """Computes the sum for each numeric columns for each group.

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples:
        --------
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
        ...     ["age", "name", "height"],
        ... )
        >>> df.show()
        +---+-----+------+
        |age| name|height|
        +---+-----+------+
        |  2|Alice|    80|
        |  3|Alice|   100|
        |  5|  Bob|   120|
        | 10|  Bob|   140|
        +---+-----+------+

        Group-by name, and calculate the sum of the age in each group.

        >>> df.groupBy("name").sum("age").sort("name").show()
        +-----+--------+
        | name|sum(age)|
        +-----+--------+
        |Alice|       5|
        |  Bob|      15|
        +-----+--------+

        Calculate the sum of the age and height in all data.

        >>> df.groupBy().sum("age", "height").show()
        +--------+-----------+
        |sum(age)|sum(height)|
        +--------+-----------+
        |      20|        440|
        +--------+-----------+
        """

    @overload
    def agg(self, *exprs: Column) -> DataFrame: ...

    @overload
    def agg(self, __exprs: dict[str, str]) -> DataFrame: ...  # noqa: PYI063

    def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame:
        """Compute aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions can be:

        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
               a full shuffle is required. Also, all the data of a group will be loaded into
               memory, so the user should be aware of the potential OOM risk if data is skewed
               and certain groups are too large to fit in memory.

           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.
        In this implementation the dict form is not available yet and raises
        ``ContributionsAcceptedError``.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
        The grouping columns are selected first, followed by the given expressions.

        .. versionadded:: 1.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        exprs : dict
            a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        Raises:
        ------
        ContributionsAcceptedError
            If a single dict is passed.
        AssertionError
            If ``exprs`` is empty or contains something other than :class:`Column`.

        Notes:
        -----
        Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
        in a single call to this function.

        Examples:
        --------
        PySpark reference output.

        >>> from pyspark.sql import functions as F
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df = spark.createDataFrame(
        ...     [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
        ... )
        >>> df.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  2|Alice|
        |  3|Alice|
        |  5|  Bob|
        | 10|  Bob|
        +---+-----+

        Group-by name, and count each group.

        >>> df.groupBy(df.name)
        GroupedData[grouping...: [name...], value: [age: bigint, name: string], type: GroupBy]

        >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
        +-----+--------+
        | name|count(1)|
        +-----+--------+
        |Alice|       2|
        |  Bob|       2|
        +-----+--------+

        Group-by name, and calculate the minimum age.

        >>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show()
        +-----+--------+
        | name|min(age)|
        +-----+--------+
        |Alice|       2|
        |  Bob|       5|
        +-----+--------+

        Same as above but uses pandas UDF.

        >>> @pandas_udf("int", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def min_udf(v):
        ...     return v.min()
        >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()  # doctest: +SKIP
        +-----+------------+
        | name|min_udf(age)|
        +-----+------------+
        |Alice|           2|
        |  Bob|           5|
        +-----+------------+
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            raise ContributionsAcceptedError

        # Columns
        assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
        selected = [*self._grouping._cols, *[c.expr for c in exprs]]
        relation = self._df.relation.select(*selected, groups=str(self._grouping))
        return DataFrame(relation, self.session)

    # TODO: add 'pivot'  # noqa: TD002, TD003
|
||||
@@ -0,0 +1,435 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100
|
||||
|
||||
from ..errors import PySparkNotImplementedError, PySparkTypeError
|
||||
from ..exception import ContributionsAcceptedError
|
||||
from .types import StructType
|
||||
|
||||
# Scalar Python types accepted as reader option values.
PrimitiveType = Union[bool, float, int, str]
# Same as PrimitiveType, but also allowing None; used to annotate ``**options``.
OptionalPrimitiveType = Optional[PrimitiveType]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from duckdb.experimental.spark.sql.dataframe import DataFrame
|
||||
from duckdb.experimental.spark.sql.session import SparkSession
|
||||
|
||||
|
||||
class DataFrameWriter:
    """Write the contents of a :class:`DataFrame` to a table or to files.

    Every method forwards to the relation held by the wrapped DataFrame
    (``dataframe.relation``). Spark options that have no equivalent here
    raise :class:`NotImplementedError` instead of being silently ignored.
    """

    def __init__(self, dataframe: "DataFrame") -> None:
        """Remember the DataFrame whose relation will be written."""
        self.dataframe = dataframe

    def saveAsTable(self, table_name: str) -> None:
        """Create a table called ``table_name`` from the DataFrame's relation."""
        relation = self.dataframe.relation
        relation.create(table_name)

    def parquet(
        self,
        path: str,
        mode: Optional[str] = None,
        partitionBy: Union[str, list[str], None] = None,
        compression: Optional[str] = None,
    ) -> None:
        """Write the DataFrame to ``path`` in Parquet format.

        Parameters
        ----------
        path : str
            Destination passed to the relation's ``write_parquet``.
        mode : str, optional
            Not supported; any truthy value raises.
        partitionBy : str or list of str, optional
            Not supported; any truthy value raises.
        compression : str, optional
            Forwarded unchanged to ``write_parquet``.

        Raises:
        ------
        NotImplementedError
            If ``mode`` or ``partitionBy`` is given.
        """
        relation = self.dataframe.relation
        if mode:
            msg = "The 'mode' option is not supported"
            raise NotImplementedError(msg)
        if partitionBy:
            msg = "The 'partitionBy' option is not supported"
            raise NotImplementedError(msg)

        relation.write_parquet(path, compression=compression)

    def csv(
        self,
        path: str,
        mode: Optional[str] = None,
        compression: Optional[str] = None,
        sep: Optional[str] = None,
        quote: Optional[str] = None,
        escape: Optional[str] = None,
        header: Optional[Union[bool, str]] = None,
        nullValue: Optional[str] = None,
        escapeQuotes: Optional[Union[bool, str]] = None,
        quoteAll: Optional[Union[bool, str]] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
        charToEscapeQuoteEscaping: Optional[str] = None,
        encoding: Optional[str] = None,
        emptyValue: Optional[str] = None,
        lineSep: Optional[str] = None,
    ) -> None:
        """Write the DataFrame to ``path`` in CSV format.

        ``sep``, ``nullValue``, ``quote``, ``compression``, ``escape``,
        ``encoding``, ``quoteAll``, ``dateFormat`` and ``timestampFormat`` are
        forwarded to the relation's ``write_csv`` under its own keyword names.

        ``header`` may be a bool or a string; a string enables the header when
        it spells "true" in any letter case. ``None`` means no header.

        Raises:
        ------
        NotImplementedError
            If ``mode`` is anything other than ``None`` or ``"overwrite"``, or
            if any of ``escapeQuotes``, ``ignoreLeadingWhiteSpace``,
            ``ignoreTrailingWhiteSpace``, ``charToEscapeQuoteEscaping``,
            ``emptyValue`` or ``lineSep`` is given a truthy value.
        """
        if mode not in (None, "overwrite"):
            msg = f"The save mode {mode!r} is not supported"
            raise NotImplementedError(msg)

        # Options with no counterpart in write_csv; checked in declaration order.
        unsupported = {
            "escapeQuotes": escapeQuotes,
            "ignoreLeadingWhiteSpace": ignoreLeadingWhiteSpace,
            "ignoreTrailingWhiteSpace": ignoreTrailingWhiteSpace,
            "charToEscapeQuoteEscaping": charToEscapeQuoteEscaping,
            "emptyValue": emptyValue,
            "lineSep": lineSep,
        }
        for option, value in unsupported.items():
            if value:
                msg = f"The '{option}' option is not supported"
                raise NotImplementedError(msg)

        # String flags are compared case-insensitively so that both "true" and
        # "True" enable the header; None and every other string disable it.
        write_header = header if isinstance(header, bool) else str(header).lower() == "true"

        relation = self.dataframe.relation
        relation.write_csv(
            path,
            sep=sep,
            na_rep=nullValue,
            quotechar=quote,
            compression=compression,
            escapechar=escape,
            header=write_header,
            encoding=encoding,
            quoting=quoteAll,
            date_format=dateFormat,
            timestamp_format=timestampFormat,
        )
|
||||
|
||||
|
||||
class DataFrameReader:
    """Load external data into a :class:`DataFrame` through the session's connection.

    Only a subset of Spark's reader options is implemented; options that are
    not supported raise instead of being silently ignored.
    """

    def __init__(self, session: "SparkSession") -> None:
        """Remember the session whose connection performs the reads."""
        self.session = session

    def load(
        self,
        path: Optional[Union[str, list[str]]] = None,
        format: Optional[str] = None,
        schema: Optional[Union[StructType, str]] = None,
        **options: OptionalPrimitiveType,
    ) -> "DataFrame":
        """Load ``path`` using ``format`` and optionally apply ``schema``.

        Parameters
        ----------
        path : str
            Source to read. Only a single string is accepted.
        format : str, optional
            One of ``csv``, ``tsv``, ``json`` or ``parquet`` (case-insensitive).
            When omitted, ``path`` is queried with ``select * from <path>``.
        schema : StructType, optional
            When given, the columns are cast to the schema's types and renamed
            to the schema's field names.
        **options
            Not supported.

        Raises:
        ------
        TypeError
            If ``path`` is not a string.
        ContributionsAcceptedError
            If options are passed, the format is unknown, or ``schema`` is not
            a :class:`StructType`.
        """
        from duckdb.experimental.spark.sql.dataframe import DataFrame

        if not isinstance(path, str):
            raise TypeError
        if options:
            raise ContributionsAcceptedError

        if format:
            format = format.lower()
            if format in ("csv", "tsv"):
                rel = self.session.conn.read_csv(path)
            elif format == "json":
                rel = self.session.conn.read_json(path)
            elif format == "parquet":
                rel = self.session.conn.read_parquet(path)
            else:
                raise ContributionsAcceptedError
        else:
            # NOTE(review): `path` is interpolated into the SQL text verbatim;
            # callers must not pass untrusted input here.
            rel = self.session.conn.sql(f"select * from {path}")

        df = DataFrame(rel, self.session)
        if schema:
            if not isinstance(schema, StructType):
                raise ContributionsAcceptedError
            types, names = schema.extract_types_and_names()
            # Both helpers are called with unpacked arguments everywhere else
            # in this package, so unpack here as well.
            df = df._cast_types(*types)
            df = df.toDF(*names)
        return df

    def csv(
        self,
        path: Union[str, list[str]],
        schema: Optional[Union[StructType, str]] = None,
        sep: Optional[str] = None,
        encoding: Optional[str] = None,
        quote: Optional[str] = None,
        escape: Optional[str] = None,
        comment: Optional[str] = None,
        header: Optional[Union[bool, str]] = None,
        inferSchema: Optional[Union[bool, str]] = None,
        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
        nullValue: Optional[str] = None,
        nanValue: Optional[str] = None,
        positiveInf: Optional[str] = None,
        negativeInf: Optional[str] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        maxColumns: Optional[Union[int, str]] = None,
        maxCharsPerColumn: Optional[Union[int, str]] = None,
        maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
        mode: Optional[str] = None,
        columnNameOfCorruptRecord: Optional[str] = None,
        multiLine: Optional[Union[bool, str]] = None,
        charToEscapeQuoteEscaping: Optional[str] = None,
        samplingRatio: Optional[Union[float, str]] = None,
        enforceSchema: Optional[Union[bool, str]] = None,
        emptyValue: Optional[str] = None,
        locale: Optional[str] = None,
        lineSep: Optional[str] = None,
        pathGlobFilter: Optional[Union[bool, str]] = None,
        recursiveFileLookup: Optional[Union[bool, str]] = None,
        modifiedBefore: Optional[Union[bool, str]] = None,
        modifiedAfter: Optional[Union[bool, str]] = None,
        unescapedQuoteHandling: Optional[str] = None,
    ) -> "DataFrame":
        """Read a single CSV file into a :class:`DataFrame`.

        ``sep``, ``nullValue``, ``quote``, ``escape``, ``encoding``,
        ``dateFormat`` and ``timestampFormat`` are forwarded to the
        connection's ``read_csv``. A :class:`StructType` ``schema`` supplies
        the column types for the read and the column names of the result.

        ``header`` may be a bool or a string; a string enables the header when
        it spells "true" in any letter case. ``None`` means no header.

        Raises:
        ------
        NotImplementedError
            If ``path`` is not a string, or ``lineSep`` is given.
        ContributionsAcceptedError
            If ``schema`` is not a :class:`StructType`, or any other
            unsupported option is given a truthy value.
        """
        if not isinstance(path, str):
            msg = "Only a single string path is supported for now"
            raise NotImplementedError(msg)
        if schema and not isinstance(schema, StructType):
            msg = "Only StructType schemas are supported"
            raise ContributionsAcceptedError(msg)

        # Options that are accepted for signature compatibility only; checked
        # in declaration order so the first offending option is reported.
        unsupported = {
            "comment": comment,
            "inferSchema": inferSchema,
            "ignoreLeadingWhiteSpace": ignoreLeadingWhiteSpace,
            "ignoreTrailingWhiteSpace": ignoreTrailingWhiteSpace,
            "nanValue": nanValue,
            "positiveInf": positiveInf,
            "negativeInf": negativeInf,
            "maxColumns": maxColumns,
            "maxCharsPerColumn": maxCharsPerColumn,
            "maxMalformedLogPerPartition": maxMalformedLogPerPartition,
            "mode": mode,
            "columnNameOfCorruptRecord": columnNameOfCorruptRecord,
            "multiLine": multiLine,
            "charToEscapeQuoteEscaping": charToEscapeQuoteEscaping,
            "samplingRatio": samplingRatio,
            "enforceSchema": enforceSchema,
            "emptyValue": emptyValue,
            "locale": locale,
            "pathGlobFilter": pathGlobFilter,
            "recursiveFileLookup": recursiveFileLookup,
            "modifiedBefore": modifiedBefore,
            "modifiedAfter": modifiedAfter,
            "unescapedQuoteHandling": unescapedQuoteHandling,
        }
        for option, value in unsupported.items():
            if value:
                msg = f"The '{option}' option is not supported"
                raise ContributionsAcceptedError(msg)
        if lineSep:
            # We have support for custom newline, just needs to be ported to 'read_csv'
            msg = "The 'lineSep' option is not supported"
            raise NotImplementedError(msg)

        dtype = None
        names = None
        if schema:
            dtype, names = schema.extract_types_and_names()

        # String flags are compared case-insensitively so that both "true" and
        # "True" enable the header; None and every other string disable it.
        has_header = header if isinstance(header, bool) else str(header).lower() == "true"

        rel = self.session.conn.read_csv(
            path,
            header=has_header,
            sep=sep,
            dtype=dtype,
            na_values=nullValue,
            quotechar=quote,
            escapechar=escape,
            encoding=encoding,
            date_format=dateFormat,
            timestamp_format=timestampFormat,
        )
        from ..sql.dataframe import DataFrame

        df = DataFrame(rel, self.session)
        if names:
            df = df.toDF(*names)
        return df

    def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame":
        """Read a single Parquet file into a :class:`DataFrame`.

        Raises:
        ------
        NotImplementedError
            If the number of paths is not exactly one.
        ContributionsAcceptedError
            If any option is passed.
        """
        if len(paths) != 1:
            msg = "Only single paths are supported for now"
            raise NotImplementedError(msg)
        if options:
            msg = "Options are not supported"
            raise ContributionsAcceptedError(msg)

        rel = self.session.conn.read_parquet(paths[0])
        from ..sql.dataframe import DataFrame

        return DataFrame(rel, self.session)

    def json(
        self,
        path: Union[str, list[str]],
        schema: Optional[Union[StructType, str]] = None,
        primitivesAsString: Optional[Union[bool, str]] = None,
        prefersDecimal: Optional[Union[bool, str]] = None,
        allowComments: Optional[Union[bool, str]] = None,
        allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
        allowSingleQuotes: Optional[Union[bool, str]] = None,
        allowNumericLeadingZero: Optional[Union[bool, str]] = None,
        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
        mode: Optional[str] = None,
        columnNameOfCorruptRecord: Optional[str] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        multiLine: Optional[Union[bool, str]] = None,
        allowUnquotedControlChars: Optional[Union[bool, str]] = None,
        lineSep: Optional[str] = None,
        samplingRatio: Optional[Union[float, str]] = None,
        dropFieldIfAllNull: Optional[Union[bool, str]] = None,
        encoding: Optional[str] = None,
        locale: Optional[str] = None,
        pathGlobFilter: Optional[Union[bool, str]] = None,
        recursiveFileLookup: Optional[Union[bool, str]] = None,
        modifiedBefore: Optional[Union[bool, str]] = None,
        modifiedAfter: Optional[Union[bool, str]] = None,
        allowNonNumericNumbers: Optional[Union[bool, str]] = None,
    ) -> "DataFrame":
        """Loads JSON files and returns the results as a :class:`DataFrame`.

        `JSON Lines <http://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
        For JSON (one record per file), set the ``multiLine`` parameter to ``true``.

        If the ``schema`` parameter is not specified, this function goes
        through the input once to determine the input schema.

        .. note::
            In this implementation only ``path`` is honoured, and it must be a
            single string or a one-element list. Passing any other parameter
            (including ``schema`` and ``multiLine``) with a value other than
            ``None`` raises ``ContributionsAcceptedError``.

        .. versionadded:: 1.4.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        path : str, list or :class:`RDD`
            string represents path to the JSON dataset, or a list of paths,
            or RDD of Strings storing JSON objects.
        schema : :class:`pyspark.sql.types.StructType` or str, optional
            an optional :class:`pyspark.sql.types.StructType` for the input schema or
            a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).

        Other Parameters
        ----------------
        Extra options
            For the extra options, refer to
            `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
            for the version you use.

            .. # noqa

        Examples:
        --------
        Write a DataFrame into a JSON file and read it back.

        >>> import tempfile
        >>> with tempfile.TemporaryDirectory() as d:
        ...     # Write a DataFrame into a JSON file
        ...     spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
        ...         "overwrite"
        ...     ).format("json").save(d)
        ...
        ...     # Read the JSON file as a DataFrame.
        ...     spark.read.json(d).show()
        +---+------------+
        |age|        name|
        +---+------------+
        |100|Hyukjin Kwon|
        +---+------------+
        """
        # Every option below is rejected as soon as it is set to anything other
        # than None; they are checked in declaration order.
        unsupported = {
            "schema": schema,
            "primitivesAsString": primitivesAsString,
            "prefersDecimal": prefersDecimal,
            "allowComments": allowComments,
            "allowUnquotedFieldNames": allowUnquotedFieldNames,
            "allowSingleQuotes": allowSingleQuotes,
            "allowNumericLeadingZero": allowNumericLeadingZero,
            "allowBackslashEscapingAnyCharacter": allowBackslashEscapingAnyCharacter,
            "mode": mode,
            "columnNameOfCorruptRecord": columnNameOfCorruptRecord,
            "dateFormat": dateFormat,
            "timestampFormat": timestampFormat,
            "multiLine": multiLine,
            "allowUnquotedControlChars": allowUnquotedControlChars,
            "lineSep": lineSep,
            "samplingRatio": samplingRatio,
            "dropFieldIfAllNull": dropFieldIfAllNull,
            "encoding": encoding,
            "locale": locale,
            "pathGlobFilter": pathGlobFilter,
            "recursiveFileLookup": recursiveFileLookup,
            "modifiedBefore": modifiedBefore,
            "modifiedAfter": modifiedAfter,
            "allowNonNumericNumbers": allowNonNumericNumbers,
        }
        for option, value in unsupported.items():
            if value is not None:
                msg = f"The '{option}' option is not supported"
                raise ContributionsAcceptedError(msg)

        if isinstance(path, str):
            path = [path]
        if not isinstance(path, list):
            raise PySparkTypeError(
                error_class="NOT_STR_OR_LIST_OF_RDD",
                message_parameters={
                    "arg_name": "path",
                    "arg_type": type(path).__name__,
                },
            )
        if len(path) != 1:
            raise PySparkNotImplementedError(message="Only a single path is supported for now")

        rel = self.session.conn.read_json(path[0])
        from .dataframe import DataFrame

        return DataFrame(rel, self.session)
|
||||
|
||||
|
||||
__all__ = ["DataFrameReader", "DataFrameWriter"]
|
||||
297
venv/Lib/site-packages/duckdb/experimental/spark/sql/session.py
Normal file
297
venv/Lib/site-packages/duckdb/experimental/spark/sql/session.py
Normal file
@@ -0,0 +1,297 @@
|
||||
import uuid # noqa: D100
|
||||
from collections.abc import Iterable, Sized
|
||||
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union
|
||||
|
||||
import duckdb
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas.core.frame import DataFrame as PandasDataFrame
|
||||
|
||||
from .catalog import Catalog
|
||||
|
||||
|
||||
from ..conf import SparkConf
|
||||
from ..context import SparkContext
|
||||
from ..errors import PySparkTypeError
|
||||
from ..exception import ContributionsAcceptedError
|
||||
from .conf import RuntimeConfig
|
||||
from .dataframe import DataFrame
|
||||
from .readwriter import DataFrameReader
|
||||
from .streaming import DataStreamReader
|
||||
from .types import StructType
|
||||
from .udf import UDFRegistration
|
||||
|
||||
# In spark:
|
||||
# SparkSession holds a SparkContext
|
||||
# SparkContext gets created from SparkConf
|
||||
# At this level the check is made to determine whether the instance already exists and just needs
|
||||
# to be retrieved or it needs to be created.
|
||||
|
||||
# For us this is done inside of `duckdb.connect`, based on the passed in path + configuration
|
||||
# SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class
|
||||
|
||||
|
||||
# data is a List of rows
|
||||
# every value in each row needs to be turned into a Value
|
||||
def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[list[duckdb.Value]]:
    """Wrap every cell of ``data`` in a ``duckdb.Value`` typed by ``schema``.

    Parameters
    ----------
    data : iterable of rows
        Each row is an iterable of raw Python values.
    schema : StructType
        Its fields are paired positionally with the values of each row; the
        ``duckdb_type`` of each field's ``dataType`` becomes the Value's type.

    Returns:
    -------
    list of list of duckdb.Value
        One inner list per input row. Pairing uses ``zip``, so a row and the
        schema are matched up to the shorter of the two.
    """
    from duckdb import Value

    # The schema does not change between rows, so collect its types once.
    field_types = [field.dataType for field in schema]
    return [[Value(x, dtype.duckdb_type) for x, dtype in zip(row, field_types)] for row in data]
|
||||
|
||||
|
||||
class SparkSession:
    """Spark-style session facade over a connection owned by a ``SparkContext``."""

    def __init__(self, context: SparkContext) -> None:
        """Bind the session to ``context`` and build its runtime configuration."""
        self.conn = context.connection
        self._context = context
        self._conf = RuntimeConfig(self.conn)

    def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame:
        """Build a DataFrame from a pandas DataFrame or from an iterable of rows.

        A pandas DataFrame is registered on the connection under a unique name
        and selected from. Any other iterable is materialised into a list,
        checked for consistent row lengths, and turned into a parameterised
        ``select * from (values ...)`` query.
        """
        try:
            import pandas
        except ImportError:
            pandas = None

        if pandas is not None and isinstance(data, pandas.DataFrame):
            unique_name = f"pyspark_pandas_df_{uuid.uuid1()}"
            self.conn.register(unique_name, data)
            return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)

        rows = data if isinstance(data, list) else list(data)

        # Every row must be as long as the first one.
        if len(rows) > 1:
            expected_length = len(rows[0])
            for i, item in enumerate(rows[1:]):
                actual_length = len(item)
                if actual_length != expected_length:
                    raise PySparkTypeError(
                        error_class="LENGTH_SHOULD_BE_THE_SAME",
                        message_parameters={
                            "arg1": f"data{i}",
                            "arg2": f"data{i + 1}",
                            "arg1_length": str(expected_length),
                            "arg2_length": str(actual_length),
                        },
                    )

        # One "($a, $b, ...)" group per row; placeholders are numbered
        # consecutively across rows, starting at $1.
        row_size = len(rows[0])
        groups = []
        for row_index, row in enumerate(rows):
            first_param = 1 + (row_index * row_size)
            placeholders = [f"${offset + first_param}" for offset in range(len(row))]
            groups.append("(" + ", ".join(placeholders) + ")")
        values_list = ", ".join(groups)
        query = f"""
                select * from (values {values_list})
            """

        # Flatten the rows into the positional parameter list, row by row.
        parameters = [value for row in rows for value in row]

        rel = self.conn.sql(query, params=parameters)
        return DataFrame(rel, self)

    def _createDataFrameFromPandas(
        self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None]
    ) -> DataFrame:
        """Create a DataFrame from pandas data, then cast and rename its columns."""
        result = self._create_dataframe(data)
        if types:
            # Cast to types
            result = result._cast_types(*types)
        if names:
            # Alias to names
            result = result.toDF(*names)
        return result

    def createDataFrame(
        self,
        data: Union["PandasDataFrame", Iterable[Any]],
        schema: Optional[Union[StructType, list[str]]] = None,
        samplingRatio: Optional[float] = None,
        verifySchema: bool = True,
    ) -> DataFrame:
        """Create a DataFrame from ``data``, optionally shaped by ``schema``.

        ``schema`` is either a :class:`StructType` (column types and names) or
        a list of column names. Falsy ``data`` combined with column names
        yields a DataFrame with those columns and no rows.

        Raises:
        ------
        NotImplementedError
            If ``samplingRatio`` is given or ``verifySchema`` is false.
        PySparkTypeError
            If ``data`` is already a :class:`DataFrame`.
        """
        if samplingRatio:
            raise NotImplementedError
        if not verifySchema:
            raise NotImplementedError

        if isinstance(data, DataFrame):
            raise PySparkTypeError(
                error_class="SHOULD_NOT_DATAFRAME",
                message_parameters={"arg_name": "data"},
            )

        types = None
        names = None
        if schema:
            if isinstance(schema, StructType):
                types, names = schema.extract_types_and_names()
            else:
                names = schema

        try:
            import pandas
        except ImportError:
            pandas = None

        # Truthiness of a pandas DataFrame is not defined, so pandas input has
        # to be dispatched before the emptiness test on 'data' below.
        if pandas is not None and isinstance(data, pandas.DataFrame):
            return self._createDataFrameFromPandas(data, types, names)

        # No data but known column names: build one all-NULL row so the
        # relation gets its columns; the row is filtered out again afterwards.
        is_empty = bool(not data and names)
        if is_empty:
            data = [tuple(None for _ in names)]

        if schema and isinstance(schema, StructType):
            # Transform the data into Values to combine the data+schema
            data = _combine_data_and_schema(data, schema)

        df = self._create_dataframe(data)
        if is_empty:
            # Add impossible where clause
            df = DataFrame(df.relation.filter("1=0"), self)

        if types:
            # Cast to types
            df = df._cast_types(*types)
        if names:
            # Alias to names
            df = df.toDF(*names)
        return df

    def newSession(self) -> "SparkSession":
        """Return a new session that shares this session's context."""
        return SparkSession(self._context)

    def range(
        self,
        start: int,
        end: Optional[int] = None,
        step: int = 1,
        numPartitions: Optional[int] = None,
    ) -> "DataFrame":
        """Return a DataFrame backed by the ``range`` table function.

        With a single argument the range runs from 0 to ``start``.
        ``numPartitions`` is not supported.
        """
        if numPartitions:
            raise ContributionsAcceptedError

        if end is None:
            start, end = 0, start

        relation = self.conn.table_function("range", parameters=[start, end, step])
        return DataFrame(relation, self)

    def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame:  # noqa: ANN401
        """Run ``sqlQuery`` on the connection; keyword arguments are not supported."""
        if kwargs:
            raise NotImplementedError
        return DataFrame(self.conn.sql(sqlQuery), self)

    def stop(self) -> None:
        """Stop the underlying context."""
        self._context.stop()

    def table(self, tableName: str) -> DataFrame:
        """Return the table ``tableName`` as a DataFrame."""
        return DataFrame(self.conn.table(tableName), self)

    def getActiveSession(self) -> "SparkSession":
        """Return this session."""
        return self

    @property
    def catalog(self) -> "Catalog":
        """Catalog for this session, created on first access and then reused."""
        try:
            return self._catalog
        except AttributeError:
            from duckdb.experimental.spark.sql.catalog import Catalog

            self._catalog = Catalog(self)
            return self._catalog

    @property
    def conf(self) -> RuntimeConfig:
        """Runtime configuration created for this session."""
        return self._conf

    @property
    def read(self) -> DataFrameReader:
        """A new :class:`DataFrameReader` bound to this session."""
        return DataFrameReader(self)

    @property
    def readStream(self) -> DataStreamReader:
        """A new :class:`DataStreamReader` bound to this session."""
        return DataStreamReader(self)

    @property
    def sparkContext(self) -> SparkContext:
        """The context this session was created from."""
        return self._context

    @property
    def streams(self) -> NoReturn:
        """Not supported; always raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def udf(self) -> UDFRegistration:
        """A new :class:`UDFRegistration` bound to this session."""
        return UDFRegistration(self)

    @property
    def version(self) -> str:
        """Fixed version string reported by this session."""
        return "1.0.0"

    class Builder:
        """Fluent builder; every setter is a no-op that returns the builder."""

        def __init__(self) -> None:
            """Create a builder; it holds no state."""

        def master(self, name: str) -> "SparkSession.Builder":
            """Accepted for compatibility; ignored."""
            return self

        def appName(self, name: str) -> "SparkSession.Builder":
            """Accepted for compatibility; ignored."""
            return self

        def remote(self, url: str) -> "SparkSession.Builder":
            """Accepted for compatibility; ignored."""
            return self

        def getOrCreate(self) -> "SparkSession":
            """Create a session on a freshly constructed ``SparkContext``."""
            return SparkSession(SparkContext("__ignored__"))

        def config(
            self,
            key: Optional[str] = None,
            value: Optional[Any] = None,  # noqa: ANN401
            conf: Optional[SparkConf] = None,
        ) -> "SparkSession.Builder":
            """Accepted for compatibility; the arguments are ignored."""
            return self

        def enableHiveSupport(self) -> "SparkSession.Builder":
            """Accepted for compatibility; ignored."""
            return self

    builder = Builder()
|
||||
|
||||
|
||||
__all__ = ["SparkSession"]
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union # noqa: D100
|
||||
|
||||
from .types import StructType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .dataframe import DataFrame
|
||||
from .session import SparkSession
|
||||
|
||||
# Scalar Python types accepted as stream-reader option values.
PrimitiveType = Union[bool, float, int, str]
# Same as PrimitiveType, but also allowing None; used to annotate ``**options``.
OptionalPrimitiveType = Optional[PrimitiveType]
|
||||
|
||||
|
||||
class DataStreamWriter:  # noqa: D101
    # Placeholder for Spark's streaming writer: it stores the DataFrame, and
    # its only method, toTable, is not implemented.
    def __init__(self, dataframe: "DataFrame") -> None:  # noqa: D107
        # Keep the DataFrame that a future implementation would write out.
        self.dataframe = dataframe

    def toTable(self, table_name: str) -> None:  # noqa: D102
        # Always raises NotImplementedError; the open design question is below.
        # Should we register the dataframe or create a table from the contents?
        raise NotImplementedError
|
||||
|
||||
|
||||
class DataStreamReader:  # noqa: D101
    # Placeholder for Spark's streaming reader: it stores the session, and
    # its only method, load, is not implemented.
    def __init__(self, session: "SparkSession") -> None:  # noqa: D107
        # Keep the session that a future implementation would read through.
        self.session = session

    def load(  # noqa: D102
        self,
        path: Optional[str] = None,
        format: Optional[str] = None,
        schema: Union[StructType, str, None] = None,
        **options: OptionalPrimitiveType,
    ) -> "DataFrame":
        # Always raises NotImplementedError; none of the arguments are inspected.
        raise NotImplementedError
|
||||
|
||||
|
||||
__all__ = ["DataStreamReader", "DataStreamWriter"]
|
||||
@@ -0,0 +1,113 @@
|
||||
from typing import cast # noqa: D100
|
||||
|
||||
from duckdb.sqltypes import DuckDBPyType
|
||||
|
||||
from ..exception import ContributionsAcceptedError
|
||||
from .types import (
|
||||
ArrayType,
|
||||
BinaryType,
|
||||
BitstringType,
|
||||
BooleanType,
|
||||
ByteType,
|
||||
DataType,
|
||||
DateType,
|
||||
DayTimeIntervalType,
|
||||
DecimalType,
|
||||
DoubleType,
|
||||
FloatType,
|
||||
HugeIntegerType,
|
||||
IntegerType,
|
||||
LongType,
|
||||
MapType,
|
||||
ShortType,
|
||||
StringType,
|
||||
StructField,
|
||||
StructType,
|
||||
TimeNTZType,
|
||||
TimestampMilisecondNTZType,
|
||||
TimestampNanosecondNTZType,
|
||||
TimestampNTZType,
|
||||
TimestampSecondNTZType,
|
||||
TimestampType,
|
||||
TimeType,
|
||||
UnsignedByteType,
|
||||
UnsignedHugeIntegerType,
|
||||
UnsignedIntegerType,
|
||||
UnsignedLongType,
|
||||
UnsignedShortType,
|
||||
UUIDType,
|
||||
)
|
||||
|
||||
# Maps the id of a non-nested DuckDB type to the Spark type class that
# represents it. convert_type instantiates the class with no arguments; nested
# ids and "decimal" are handled by dedicated branches before this table is
# consulted.
_sqltype_to_spark_class = {
    "boolean": BooleanType,
    "utinyint": UnsignedByteType,
    "tinyint": ByteType,
    "usmallint": UnsignedShortType,
    "smallint": ShortType,
    "uinteger": UnsignedIntegerType,
    "integer": IntegerType,
    "ubigint": UnsignedLongType,
    "bigint": LongType,
    "hugeint": HugeIntegerType,
    "uhugeint": UnsignedHugeIntegerType,
    "varchar": StringType,
    "blob": BinaryType,
    "bit": BitstringType,
    "uuid": UUIDType,
    "date": DateType,
    "time": TimeNTZType,
    "time with time zone": TimeType,
    "timestamp": TimestampNTZType,
    "timestamp with time zone": TimestampType,
    # Millisecond and nanosecond precision each map to the class named after them.
    "timestamp_ms": TimestampMilisecondNTZType,
    "timestamp_ns": TimestampNanosecondNTZType,
    "timestamp_s": TimestampSecondNTZType,
    "interval": DayTimeIntervalType,
    "list": ArrayType,
    "struct": StructType,
    "map": MapType,
    # union
    # enum
    # null (???)
    "float": FloatType,
    "double": DoubleType,
    "decimal": DecimalType,
}
|
||||
|
||||
|
||||
def convert_nested_type(dtype: DuckDBPyType) -> DataType:
    """Convert a nested DuckDB type (list, array, struct or map) to its Spark type.

    Child types are converted recursively through ``convert_type``.

    Raises:
    ------
    ContributionsAcceptedError
        If ``dtype`` is a union, which has no Spark counterpart.
    NotImplementedError
        If ``dtype`` is any other type.
    """
    type_id = dtype.id

    if type_id in ("list", "array"):
        # The element type is the second item of the first child entry.
        element_type = dtype.children[0][1]
        return ArrayType(convert_type(element_type))

    if type_id == "union":
        msg = (
            "Union types are not supported in the PySpark interface. "
            "DuckDB union types cannot be directly mapped to PySpark types."
        )
        raise ContributionsAcceptedError(msg)

    if type_id == "struct":
        members: list[tuple[str, DuckDBPyType]] = dtype.children
        return StructType([StructField(name, convert_type(member)) for name, member in members])

    if type_id == "map":
        key_type = convert_type(dtype.key)
        value_type = convert_type(dtype.value)
        return MapType(key_type, value_type)

    raise NotImplementedError
|
||||
|
||||
|
||||
def convert_type(dtype: DuckDBPyType) -> DataType:
    """Convert a DuckDB type to the corresponding Spark ``DataType`` instance.

    Nested types are delegated to ``convert_nested_type``. Union types are
    delegated as well, so that they fail with its descriptive
    ``ContributionsAcceptedError`` rather than a bare ``KeyError``. Decimals
    carry their precision and scale over; every other type is looked up in
    ``_sqltype_to_spark_class`` and instantiated without arguments.

    Raises:
    ------
    ContributionsAcceptedError
        If ``dtype`` is a union type.
    KeyError
        If the type id has no entry in ``_sqltype_to_spark_class``.
    """
    type_id = dtype.id
    if type_id in ("list", "struct", "map", "array", "union"):
        return convert_nested_type(dtype)
    if type_id == "decimal":
        # The first child entry holds the precision, the second the scale.
        children: list[tuple[str, DuckDBPyType]] = dtype.children
        precision = cast("int", children[0][1])
        scale = cast("int", children[1][1])
        return DecimalType(precision, scale)
    spark_type = _sqltype_to_spark_class[type_id]
    return spark_type()
|
||||
|
||||
|
||||
def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType:
    """Build a Spark ``StructType`` from parallel lists of column names and DuckDB types.

    Every entry of ``types`` is converted with ``convert_type`` first; names
    and converted types are then paired positionally.
    """
    spark_types = list(map(convert_type, types))
    columns = []
    for column_name, spark_type in zip(names, spark_types):
        columns.append(StructField(column_name, spark_type))
    return StructType(columns)
|
||||
1310
venv/Lib/site-packages/duckdb/experimental/spark/sql/types.py
Normal file
1310
venv/Lib/site-packages/duckdb/experimental/spark/sql/types.py
Normal file
File diff suppressed because it is too large
Load Diff
37
venv/Lib/site-packages/duckdb/experimental/spark/sql/udf.py
Normal file
37
venv/Lib/site-packages/duckdb/experimental/spark/sql/udf.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
from .types import DataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .session import SparkSession
|
||||
|
||||
# A return type given either as a DataType instance or as a type string.
DataTypeOrString = Union[DataType, str]
# Type variable used in the signature of UDFRegistration.register.
UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike")
|
||||
|
||||
|
||||
class UDFRegistration:
    """Register Python callables as SQL functions on a session's connection."""

    def __init__(self, sparkSession: "SparkSession") -> None:
        """Remember the session whose connection receives the registrations."""
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
        returnType: Optional["DataTypeOrString"] = None,
    ) -> "UserDefinedFunctionLike":
        """Register ``f`` under ``name`` and return it.

        Parameters
        ----------
        name : str
            Name the function is registered under.
        f : callable
            The function to register.
        returnType : DataType or str, optional
            Forwarded to the connection's ``create_function`` as ``return_type``.

        Returns:
        -------
        The callable that was registered, as the annotated return type
        promises, so the result of ``register`` can be kept and used.
        """
        self.sparkSession.conn.create_function(name, f, return_type=returnType)
        return f

    def registerJavaFunction(
        self,
        name: str,
        javaClassName: str,
        returnType: Optional["DataTypeOrString"] = None,
    ) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
||||
|
||||
|
||||
__all__ = ["UDFRegistration"]
|
||||
Reference in New Issue
Block a user