initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from . import spark # noqa: D104
__all__ = [
"spark",
]

View File

@@ -0,0 +1,260 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache Software Foundation License 2.0
--------------------------------------
common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
core/src/main/java/org/apache/spark/util/collection/TimSort.java
core/src/main/resources/org/apache/spark/ui/static/bootstrap*
core/src/main/resources/org/apache/spark/ui/static/vis*
docs/js/vendor/bootstrap.js
connector/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java
Python Software Foundation License
----------------------------------
python/docs/source/_static/copybutton.js
BSD 3-Clause
------------
python/lib/py4j-*-src.zip
python/pyspark/cloudpickle/*.py
python/pyspark/join.py
core/src/main/resources/org/apache/spark/ui/static/d3.min.js
The CSS style for the navigation sidebar of the documentation was originally
submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project
is distributed under the 3-Clause BSD license.
MIT License
-----------
core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
core/src/main/resources/org/apache/spark/ui/static/*dataTables*
core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js
core/src/main/resources/org/apache/spark/ui/static/jquery*
core/src/main/resources/org/apache/spark/ui/static/sorttable.js
docs/js/vendor/anchor.min.js
docs/js/vendor/jquery*
docs/js/vendor/modernizer*
Creative Commons CC0 1.0 Universal Public Domain Dedication
-----------------------------------------------------------
(see LICENSE-CC0.txt)
data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
data/mllib/images/kittens/54893.jpg
data/mllib/images/kittens/DP153539.jpg
data/mllib/images/kittens/DP802813.jpg
data/mllib/images/multi-channel/chr30.4.184.jpg

View File

@@ -0,0 +1,6 @@
from .conf import SparkConf # noqa: D104
from .context import SparkContext
from .exception import ContributionsAcceptedError
from .sql import DataFrame, SparkSession
__all__ = ["ContributionsAcceptedError", "DataFrame", "SparkConf", "SparkContext", "SparkSession"]

View File

@@ -0,0 +1,77 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module defining global singleton classes.
This module raises a RuntimeError if an attempt to reload it is made. In that
way the identities of the classes defined here are fixed and will remain so
even if duckdb spark itself is reloaded. In particular, a function like the following
will still work correctly after duckdb spark is reloaded:
def foo(arg=pyducdkb.spark._NoValue):
if arg is pyducdkb.spark._NoValue:
...
See gh-7844 for a discussion of the reload problem that motivated this module.
Note that this approach is taken after from NumPy.
"""
__ALL__ = ["_NoValue"]
# Disallow reloading this module so as to preserve the identities of the
# classes defined here.
if "_is_loaded" in globals():
msg = "Reloading duckdb.experimental.spark._globals is not allowed"
raise RuntimeError(msg)
_is_loaded = True
class _NoValueType:
"""Special keyword value.
The instance of this class may be used as the default value assigned to a
deprecated keyword in order to check if it has been given a user defined
value.
This class was copied from NumPy.
"""
__instance = None
def __new__(cls) -> "_NoValueType":
# ensure that only one instance exists
if not cls.__instance:
cls.__instance = super().__new__(cls)
return cls.__instance
# Make the _NoValue instance falsey
def __nonzero__(self) -> bool:
return False
__bool__ = __nonzero__
# needed for python 2 to preserve identity through a pickle
def __reduce__(self) -> tuple[type, tuple]:
return (self.__class__, ())
def __repr__(self) -> str:
return "<no value>"
_NoValue = _NoValueType()

View File

@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterable, Sized
from typing import Callable, TypeVar, Union
from numpy import float32, float64, int32, int64, ndarray
from typing_extensions import Literal, Protocol, Self
F = TypeVar("F", bound=Callable)
T_co = TypeVar("T_co", covariant=True)
PrimitiveType = Union[bool, float, int, str]
NonUDFType = Literal[0]
class SupportsIAdd(Protocol):
def __iadd__(self, other: "SupportsIAdd") -> Self: ...
class SupportsOrdering(Protocol):
def __lt__(self, other: "SupportsOrdering") -> bool: ...
class SizedIterable(Protocol, Sized, Iterable[T_co]): ...
S = TypeVar("S", bound=SupportsOrdering)
NumberOrArray = TypeVar("NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray)

View File

@@ -0,0 +1,46 @@
from typing import Optional # noqa: D100
from duckdb.experimental.spark.exception import ContributionsAcceptedError
class SparkConf: # noqa: D101
def __init__(self) -> None: # noqa: D107
raise NotImplementedError
def contains(self, key: str) -> bool: # noqa: D102
raise ContributionsAcceptedError
def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: # noqa: D102
raise ContributionsAcceptedError
def getAll(self) -> list[tuple[str, str]]: # noqa: D102
raise ContributionsAcceptedError
def set(self, key: str, value: str) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def setAppName(self, value: str) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def setExecutorEnv( # noqa: D102
self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None
) -> "SparkConf":
raise ContributionsAcceptedError
def setIfMissing(self, key: str, value: str) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def setMaster(self, value: str) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def setSparkHome(self, value: str) -> "SparkConf": # noqa: D102
raise ContributionsAcceptedError
def toDebugString(self) -> str: # noqa: D102
raise ContributionsAcceptedError
__all__ = ["SparkConf"]

View File

@@ -0,0 +1,180 @@
from typing import Optional # noqa: D100
import duckdb
from duckdb import DuckDBPyConnection
from duckdb.experimental.spark.conf import SparkConf
from duckdb.experimental.spark.exception import ContributionsAcceptedError
class SparkContext: # noqa: D101
def __init__(self, master: str) -> None: # noqa: D107
self._connection = duckdb.connect(":memory:")
# This aligns the null ordering with Spark.
self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'")
@property
def connection(self) -> DuckDBPyConnection: # noqa: D102
return self._connection
def stop(self) -> None: # noqa: D102
self._connection.close()
@classmethod
def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": # noqa: D102
raise ContributionsAcceptedError
@classmethod
def setSystemProperty(cls, key: str, value: str) -> None: # noqa: D102
raise ContributionsAcceptedError
@property
def applicationId(self) -> str: # noqa: D102
raise ContributionsAcceptedError
@property
def defaultMinPartitions(self) -> int: # noqa: D102
raise ContributionsAcceptedError
@property
def defaultParallelism(self) -> int: # noqa: D102
raise ContributionsAcceptedError
# @property
# def resources(self) -> Dict[str, ResourceInformation]:
# raise ContributionsAcceptedError
@property
def startTime(self) -> str: # noqa: D102
raise ContributionsAcceptedError
@property
def uiWebUrl(self) -> str: # noqa: D102
raise ContributionsAcceptedError
@property
def version(self) -> str: # noqa: D102
raise ContributionsAcceptedError
def __repr__(self) -> str: # noqa: D105
raise ContributionsAcceptedError
# def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None
# ) -> 'Accumulator[T]':
# pass
def addArchive(self, path: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def addFile(self, path: str, recursive: bool = False) -> None: # noqa: D102
raise ContributionsAcceptedError
def addPyFile(self, path: str) -> None: # noqa: D102
raise ContributionsAcceptedError
# def binaryFiles(self, path: str, minPartitions: Optional[int] = None
# ) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]:
# pass
# def binaryRecords(self, path: str, recordLength: int) -> duckdb.experimental.spark.rdd.RDD[bytes]:
# pass
# def broadcast(self, value: ~T) -> 'Broadcast[T]':
# pass
def cancelAllJobs(self) -> None: # noqa: D102
raise ContributionsAcceptedError
def cancelJobGroup(self, groupId: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def dump_profiles(self, path: str) -> None: # noqa: D102
raise ContributionsAcceptedError
# def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]:
# pass
def getCheckpointDir(self) -> Optional[str]: # noqa: D102
raise ContributionsAcceptedError
def getConf(self) -> SparkConf: # noqa: D102
raise ContributionsAcceptedError
def getLocalProperty(self, key: str) -> Optional[str]: # noqa: D102
raise ContributionsAcceptedError
# def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str,
# keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
# conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
# pass
# def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None,
# valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0
# ) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
# pass
# def newAPIHadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str,
# keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
# conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
# pass
# def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str,
# keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
# conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
# pass
# def parallelize(self, c: Iterable[~T], numSlices: Optional[int] = None) -> pyspark.rdd.RDD[~T]:
# pass
# def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> pyspark.rdd.RDD[typing.Any]:
# pass
# def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None
# ) -> pyspark.rdd.RDD[int]:
# pass
# def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]],
# partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]:
# pass
# def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None,
# keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None,
# batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
# pass
def setCheckpointDir(self, dirName: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def setJobDescription(self, value: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: # noqa: D102
raise ContributionsAcceptedError
def setLocalProperty(self, key: str, value: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def setLogLevel(self, logLevel: str) -> None: # noqa: D102
raise ContributionsAcceptedError
def show_profiles(self) -> None: # noqa: D102
raise ContributionsAcceptedError
def sparkUser(self) -> str: # noqa: D102
raise ContributionsAcceptedError
# def statusTracker(self) -> duckdb.experimental.spark.status.StatusTracker:
# raise ContributionsAcceptedError
# def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True
# ) -> pyspark.rdd.RDD[str]:
# pass
# def union(self, rdds: List[pyspark.rdd.RDD[~T]]) -> pyspark.rdd.RDD[~T]:
# pass
# def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True
# ) -> pyspark.rdd.RDD[typing.Tuple[str, str]]:
# pass
__all__ = ["SparkContext"]

View File

@@ -0,0 +1,70 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PySpark exceptions."""
from .exceptions.base import (
AnalysisException,
ArithmeticException,
ArrayIndexOutOfBoundsException,
DateTimeException,
IllegalArgumentException,
NumberFormatException,
ParseException,
PySparkAssertionError,
PySparkAttributeError,
PySparkException,
PySparkIndexError,
PySparkNotImplementedError,
PySparkRuntimeError,
PySparkTypeError,
PySparkValueError,
PythonException,
QueryExecutionException,
SparkRuntimeException,
SparkUpgradeException,
StreamingQueryException,
TempTableAlreadyExistsException,
UnknownException,
UnsupportedOperationException,
)
__all__ = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkAssertionError",
"PySparkAttributeError",
"PySparkException",
"PySparkIndexError",
"PySparkNotImplementedError",
"PySparkRuntimeError",
"PySparkTypeError",
"PySparkValueError",
"PythonException",
"QueryExecutionException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"TempTableAlreadyExistsException",
"UnknownException",
"UnsupportedOperationException",
]

View File

@@ -0,0 +1,918 @@
# ruff: noqa: D100, E501
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
ERROR_CLASSES_JSON = """
{
"APPLICATION_NAME_NOT_SET" : {
"message" : [
"An application name must be set in your configuration."
]
},
"ARGUMENT_REQUIRED": {
"message": [
"Argument `<arg_name>` is required when <condition>."
]
},
"ATTRIBUTE_NOT_CALLABLE" : {
"message" : [
"Attribute `<attr_name>` in provided object `<obj_name>` is not callable."
]
},
"ATTRIBUTE_NOT_SUPPORTED" : {
"message" : [
"Attribute `<attr_name>` is not supported."
]
},
"AXIS_LENGTH_MISMATCH" : {
"message" : [
"Length mismatch: Expected axis has <expected_length> element, new values have <actual_length> elements."
]
},
"BROADCAST_VARIABLE_NOT_LOADED": {
"message": [
"Broadcast variable `<variable>` not loaded."
]
},
"CALL_BEFORE_INITIALIZE": {
"message": [
"Not supported to call `<func_name>` before initialize <object>."
]
},
"CANNOT_ACCEPT_OBJECT_IN_TYPE": {
"message": [
"`<data_type>` can not accept object `<obj_name>` in type `<obj_type>`."
]
},
"CANNOT_ACCESS_TO_DUNDER": {
"message": [
"Dunder(double underscore) attribute is for internal use only."
]
},
"CANNOT_APPLY_IN_FOR_COLUMN": {
"message": [
"Cannot apply 'in' operator against a column: please use 'contains' in a string column or 'array_contains' function for an array column."
]
},
"CANNOT_BE_EMPTY": {
"message": [
"At least one <item> must be specified."
]
},
"CANNOT_BE_NONE": {
"message": [
"Argument `<arg_name>` can not be None."
]
},
"CANNOT_CONVERT_COLUMN_INTO_BOOL": {
"message": [
"Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions."
]
},
"CANNOT_CONVERT_TYPE": {
"message": [
"Cannot convert <from_type> into <to_type>."
]
},
"CANNOT_DETERMINE_TYPE": {
"message": [
"Some of types cannot be determined after inferring."
]
},
"CANNOT_GET_BATCH_ID": {
"message": [
"Could not get batch id from <obj_name>."
]
},
"CANNOT_INFER_ARRAY_TYPE": {
"message": [
"Can not infer Array Type from an list with None as the first element."
]
},
"CANNOT_INFER_EMPTY_SCHEMA": {
"message": [
"Can not infer schema from empty dataset."
]
},
"CANNOT_INFER_SCHEMA_FOR_TYPE": {
"message": [
"Can not infer schema for type: `<data_type>`."
]
},
"CANNOT_INFER_TYPE_FOR_FIELD": {
"message": [
"Unable to infer the type of the field `<field_name>`."
]
},
"CANNOT_MERGE_TYPE": {
"message": [
"Can not merge type `<data_type1>` and `<data_type2>`."
]
},
"CANNOT_OPEN_SOCKET": {
"message": [
"Can not open socket: <errors>."
]
},
"CANNOT_PARSE_DATATYPE": {
"message": [
"Unable to parse datatype. <msg>."
]
},
"CANNOT_PROVIDE_METADATA": {
"message": [
"metadata can only be provided for a single column."
]
},
"CANNOT_SET_TOGETHER": {
"message": [
"<arg_list> should not be set together."
]
},
"CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF": {
"message": [
"returnType can not be specified when `<arg_name>` is a user-defined function, but got <return_type>."
]
},
"COLUMN_IN_LIST": {
"message": [
"`<func_name>` does not allow a Column in a list."
]
},
"CONTEXT_ONLY_VALID_ON_DRIVER" : {
"message" : [
"It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063."
]
},
"CONTEXT_UNAVAILABLE_FOR_REMOTE_CLIENT" : {
"message" : [
"Remote client cannot create a SparkContext. Create SparkSession instead."
]
},
"DIFFERENT_PANDAS_DATAFRAME" : {
"message" : [
"DataFrames are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_INDEX" : {
"message" : [
"Indices are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_MULTIINDEX" : {
"message" : [
"MultiIndices are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_SERIES" : {
"message" : [
"Series are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_ROWS" : {
"message" : [
"<error_msg>"
]
},
"DIFFERENT_SCHEMA" : {
"message" : [
"Schemas do not match.",
"--- actual",
"+++ expected",
"<error_msg>"
]
},
"DISALLOWED_TYPE_FOR_CONTAINER" : {
"message" : [
"Argument `<arg_name>`(type: <arg_type>) should only contain a type in [<allowed_types>], got <return_type>"
]
},
"DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : {
"message" : [
"Duplicated field names in Arrow Struct are not allowed, got <field_names>"
]
},
"HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : {
"message" : [
"Function `<func_name>` should return Column, got <return_type>."
]
},
"INCORRECT_CONF_FOR_PROFILE" : {
"message" : [
"`spark.python.profile` or `spark.python.profile.memory` configuration",
" must be set to `true` to enable Python profile."
]
},
"INVALID_ARROW_UDTF_RETURN_TYPE" : {
"message" : [
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>."
]
},
"INVALID_BROADCAST_OPERATION": {
"message": [
"Broadcast can only be <operation> in driver."
]
},
"INVALID_CALL_ON_UNRESOLVED_OBJECT": {
"message": [
"Invalid call to `<func_name>` on unresolved object."
]
},
"INVALID_CONNECT_URL" : {
"message" : [
"Invalid URL for Spark Connect: <detail>"
]
},
"INVALID_ITEM_FOR_CONTAINER": {
"message": [
"All items in `<arg_name>` should be in <allowed_types>, got <item_type>."
]
},
"INVALID_NDARRAY_DIMENSION": {
"message": [
"NumPy array input should be of <dimensions> dimensions."
]
},
"INVALID_PANDAS_UDF" : {
"message" : [
"Invalid function: <detail>"
]
},
"INVALID_PANDAS_UDF_TYPE" : {
"message" : [
"`<arg_name>` should be one the values from PandasUDFType, got <arg_type>"
]
},
"INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
"message": [
"Pandas UDF should return StructType for <eval_type>, got <return_type>."
]
},
"INVALID_TIMEOUT_TIMESTAMP" : {
"message" : [
"Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
]
},
"INVALID_TYPE" : {
"message" : [
"Argument `<arg_name>` should not be a <data_type>."
]
},
"INVALID_TYPENAME_CALL" : {
"message" : [
"StructField does not have typeName. Use typeName on its type explicitly instead."
]
},
"INVALID_TYPE_DF_EQUALITY_ARG" : {
"message" : [
"Expected type <expected_type> for `<arg_name>` but got type <actual_type>."
]
},
"INVALID_UDF_EVAL_TYPE" : {
"message" : [
"Eval type for UDF must be <eval_type>."
]
},
"INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE" : {
"message" : [
"The UDTF '<name>' is invalid. It has both its return type and an 'analyze' attribute. Please make it have one of either the return type or the 'analyze' static method in '<name>' and try again."
]
},
"INVALID_UDTF_EVAL_TYPE" : {
"message" : [
"The eval type for the UDTF '<name>' is invalid. It must be one of <eval_type>."
]
},
"INVALID_UDTF_HANDLER_TYPE" : {
"message" : [
"The UDTF is invalid. The function handler must be a class, but got '<type>'. Please provide a class as the function handler."
]
},
"INVALID_UDTF_NO_EVAL" : {
"message" : [
"The UDTF '<name>' is invalid. It does not implement the required 'eval' method. Please implement the 'eval' method in '<name>' and try again."
]
},
"INVALID_UDTF_RETURN_TYPE" : {
"message" : [
"The UDTF '<name>' is invalid. It does not specify its return type or implement the required 'analyze' static method. Please specify the return type or implement the 'analyze' static method in '<name>' and try again."
]
},
"INVALID_WHEN_USAGE": {
"message": [
"when() can only be applied on a Column previously generated by when() function, and cannot be applied once otherwise() is applied."
]
},
"INVALID_WINDOW_BOUND_TYPE" : {
"message" : [
"Invalid window bound type: <window_bound_type>."
]
},
"JAVA_GATEWAY_EXITED" : {
"message" : [
"Java gateway process exited before sending its port number."
]
},
"JVM_ATTRIBUTE_NOT_SUPPORTED" : {
"message" : [
"Attribute `<attr_name>` is not supported in Spark Connect as it depends on the JVM. If you need to use this attribute, do not use Spark Connect when creating your session."
]
},
"KEY_VALUE_PAIR_REQUIRED" : {
"message" : [
"Key-value pair or a list of pairs is required."
]
},
"LENGTH_SHOULD_BE_THE_SAME" : {
"message" : [
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MASTER_URL_NOT_SET" : {
"message" : [
"A master URL must be set in your configuration."
]
},
"MISSING_LIBRARY_FOR_PROFILER" : {
"message" : [
"Install the 'memory_profiler' library in the cluster to enable memory profiling."
]
},
"MISSING_VALID_PLAN" : {
"message" : [
"Argument to <operator> does not contain a valid plan."
]
},
"MIXED_TYPE_REPLACEMENT" : {
"message" : [
"Mixed type replacements are not supported."
]
},
"NEGATIVE_VALUE" : {
"message" : [
"Value for `<arg_name>` must be greater than or equal to 0, got '<arg_value>'."
]
},
"NOT_BOOL" : {
"message" : [
"Argument `<arg_name>` should be a bool, got <arg_type>."
]
},
"NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a bool, dict, float, int, str or tuple, got <arg_type>."
]
},
"NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool, dict, float, int or str, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a bool, float or str, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a bool, float, int, list, None, str or tuple, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool, float, int or str, got <arg_type>."
]
},
"NOT_BOOL_OR_LIST" : {
"message" : [
"Argument `<arg_name>` should be a bool or list, got <arg_type>."
]
},
"NOT_BOOL_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool or str, got <arg_type>."
]
},
"NOT_CALLABLE" : {
"message" : [
"Argument `<arg_name>` should be a callable, got <arg_type>."
]
},
"NOT_COLUMN" : {
"message" : [
"Argument `<arg_name>` should be a Column, got <arg_type>."
]
},
"NOT_COLUMN_OR_DATATYPE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, str or DataType, but got <arg_type>."
]
},
"NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a column, float, integer, list or string, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a Column or int, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a Column, int, list, str or tuple, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, int or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, list or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_STR_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a StructType, Column or str, got <arg_type>."
]
},
"NOT_DATAFRAME" : {
"message" : [
"Argument `<arg_name>` should be a DataFrame, got <arg_type>."
]
},
"NOT_DATATYPE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a DataType or str, got <arg_type>."
]
},
"NOT_DICT" : {
"message" : [
"Argument `<arg_name>` should be a dict, got <arg_type>."
]
},
"NOT_EXPRESSION" : {
"message" : [
"Argument `<arg_name>` should be a Expression, got <arg_type>."
]
},
"NOT_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a float or int, got <arg_type>."
]
},
"NOT_FLOAT_OR_INT_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a float, int, list or str, got <arg_type>."
]
},
"NOT_IMPLEMENTED" : {
"message" : [
"<feature> is not implemented."
]
},
"NOT_INSTANCE_OF" : {
"message" : [
"<value> is not an instance of type <data_type>."
]
},
"NOT_INT" : {
"message" : [
"Argument `<arg_name>` should be an int, got <arg_type>."
]
},
"NOT_INT_OR_SLICE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be an int, slice or str, got <arg_type>."
]
},
"NOT_IN_BARRIER_STAGE" : {
"message" : [
"It is not in a barrier stage."
]
},
"NOT_ITERABLE" : {
"message" : [
"<objectName> is not iterable."
]
},
"NOT_LIST" : {
"message" : [
"Argument `<arg_name>` should be a list, got <arg_type>."
]
},
"NOT_LIST_OF_COLUMN" : {
"message" : [
"Argument `<arg_name>` should be a list[Column]."
]
},
"NOT_LIST_OF_COLUMN_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a list[Column]."
]
},
"NOT_LIST_OF_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a list[float, int], got <arg_type>."
]
},
"NOT_LIST_OF_STR" : {
"message" : [
"Argument `<arg_name>` should be a list[str], got <arg_type>."
]
},
"NOT_LIST_OR_NONE_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a list, None or StructType, got <arg_type>."
]
},
"NOT_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a list, str or tuple, got <arg_type>."
]
},
"NOT_LIST_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a list or tuple, got <arg_type>."
]
},
"NOT_NUMERIC_COLUMNS" : {
"message" : [
"Numeric aggregation function can only be applied on numeric columns, got <invalid_columns>."
]
},
"NOT_OBSERVATION_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Observation or str, got <arg_type>."
]
},
"NOT_SAME_TYPE" : {
"message" : [
"Argument `<arg_name1>` and `<arg_name2>` should be the same type, got <arg_type1> and <arg_type2>."
]
},
"NOT_STR" : {
"message" : [
"Argument `<arg_name>` should be a str, got <arg_type>."
]
},
"NOT_STR_OR_LIST_OF_RDD" : {
"message" : [
"Argument `<arg_name>` should be a str or list[RDD], got <arg_type>."
]
},
"NOT_STR_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a str or structType, got <arg_type>."
]
},
"NOT_WINDOWSPEC" : {
"message" : [
"Argument `<arg_name>` should be a WindowSpec, got <arg_type>."
]
},
"NO_ACTIVE_OR_DEFAULT_SESSION" : {
"message" : [
"No active or default Spark session found. Please create a new Spark session before running the code."
]
},
"NO_ACTIVE_SESSION" : {
"message" : [
"No active Spark session found. Please create a new Spark session before running the code."
]
},
"ONLY_ALLOWED_FOR_SINGLE_COLUMN" : {
"message" : [
"Argument `<arg_name>` can only be provided for a single column."
]
},
"ONLY_ALLOW_SINGLE_TRIGGER" : {
"message" : [
"Only a single trigger is allowed."
]
},
"PIPE_FUNCTION_EXITED" : {
"message" : [
"Pipe function `<func_name>` exited with error code <error_code>."
]
},
"PYTHON_HASH_SEED_NOT_SET" : {
"message" : [
"Randomness of hash of string should be disabled via PYTHONHASHSEED."
]
},
"PYTHON_VERSION_MISMATCH" : {
"message" : [
"Python in worker has different version: <worker_version> than that in driver: <driver_version>, PySpark cannot run with different minor versions.",
"Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set."
]
},
"RDD_TRANSFORM_ONLY_VALID_ON_DRIVER" : {
"message" : [
"It appears that you are attempting to broadcast an RDD or reference an RDD from an ",
"action or transformation. RDD transformations and actions can only be invoked by the ",
"driver, not inside of other transformations; for example, ",
"rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values ",
"transformation and count action cannot be performed inside of the rdd1.map ",
"transformation. For more information, see SPARK-5063."
]
},
"RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>"
]
},
"RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Number of columns of the returned pandas.DataFrame doesn't match specified schema. Expected: <expected> Actual: <actual>"
]
},
"RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF" : {
"message" : [
"The length of output in Scalar iterator pandas UDF should be the same with the input's; however, the length of output was <output_length> and the length of input was <input_length>."
]
},
"SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Result vector from pandas_udf was not the required length: expected <expected>, got <actual>."
]
},
"SESSION_ALREADY_EXIST" : {
"message" : [
"Cannot start a remote Spark session because there is a regular Spark session already running."
]
},
"SESSION_NOT_SAME" : {
"message" : [
"Both Datasets must belong to the same SparkSession."
]
},
"SESSION_OR_CONTEXT_EXISTS" : {
"message" : [
"There should not be an existing Spark Session or Spark Context."
]
},
"SHOULD_NOT_DATAFRAME": {
"message": [
"Argument `<arg_name>` should not be a DataFrame."
]
},
"SLICE_WITH_STEP" : {
"message" : [
"Slice with step is not supported."
]
},
"STATE_NOT_EXISTS" : {
"message" : [
"State is either not defined or has already been removed."
]
},
"STOP_ITERATION_OCCURRED" : {
"message" : [
"Caught StopIteration thrown from user's code; failing the task: <exc>"
]
},
"STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF" : {
"message" : [
"pandas iterator UDF should exhaust the input iterator."
]
},
"STREAMING_CONNECT_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the function `<name>`. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`."
]
},
"TOO_MANY_VALUES" : {
"message" : [
"Expected <expected> values for `<item>`, got <actual>."
]
},
"UDF_RETURN_TYPE" : {
"message" : [
"Return type of the user-defined function should be <expected>, but is <actual>."
]
},
"UDTF_ARROW_TYPE_CAST_ERROR" : {
"message" : [
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
]
},
"UDTF_EXEC_ERROR" : {
"message" : [
"User defined table function encountered an error in the '<method_name>' method: <error>"
]
},
"UDTF_INVALID_OUTPUT_ROW_TYPE" : {
"message" : [
"The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
]
},
"UDTF_RETURN_NOT_ITERABLE" : {
"message" : [
"The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
]
},
"UDTF_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema."
]
},
"UDTF_RETURN_TYPE_MISMATCH" : {
"message" : [
"Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', but got '<return_type>'. Please ensure the return type is a correctly formatted StructType."
]
},
"UDTF_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the UDTF '<name>': <message>"
]
},
"UNEXPECTED_RESPONSE_FROM_SERVER" : {
"message" : [
"Unexpected response from iterator server."
]
},
"UNEXPECTED_TUPLE_WITH_STRUCT" : {
"message" : [
"Unexpected tuple <tuple> with StructType."
]
},
"UNKNOWN_EXPLAIN_MODE" : {
"message" : [
"Unknown explain mode: '<explain_mode>'. Accepted explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'."
]
},
"UNKNOWN_INTERRUPT_TYPE" : {
"message" : [
"Unknown interrupt type: '<interrupt_type>'. Accepted interrupt types are 'all'."
]
},
"UNKNOWN_RESPONSE" : {
"message" : [
"Unknown response: <response>."
]
},
"UNSUPPORTED_DATA_TYPE" : {
"message" : [
"Unsupported DataType `<data_type>`."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW" : {
"message" : [
"Single data type <data_type> is not supported with Arrow."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : {
"message" : [
"<data_type> is not supported in conversion to Arrow."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION" : {
"message" : [
"<data_type> is only supported with pyarrow 2.0.0 and above."
]
},
"UNSUPPORTED_JOIN_TYPE" : {
"message" : [
"Unsupported join type: <join_type>. Supported join types include: \\"inner\\", \\"outer\\", \\"full\\", \\"fullouter\\", \\"full_outer\\", \\"leftouter\\", \\"left\\", \\"left_outer\\", \\"rightouter\\", \\"right\\", \\"right_outer\\", \\"leftsemi\\", \\"left_semi\\", \\"semi\\", \\"leftanti\\", \\"left_anti\\", \\"anti\\", \\"cross\\"."
]
},
"UNSUPPORTED_LITERAL" : {
"message" : [
"Unsupported Literal '<literal>'."
]
},
"UNSUPPORTED_NUMPY_ARRAY_SCALAR" : {
"message" : [
"The type of array scalar '<dtype>' is not supported."
]
},
"UNSUPPORTED_OPERATION" : {
"message" : [
"<operation> is not supported."
]
},
"UNSUPPORTED_PARAM_TYPE_FOR_HIGHER_ORDER_FUNCTION" : {
"message" : [
"Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
]
},
"UNSUPPORTED_SIGNATURE" : {
"message" : [
"Unsupported signature: <signature>."
]
},
"UNSUPPORTED_WITH_ARROW_OPTIMIZATION" : {
"message" : [
"<feature> is not supported with Arrow optimization enabled in Python UDFs. Disable 'spark.sql.execution.pythonUDF.arrow.enabled' to workaround.."
]
},
"VALUE_NOT_ACCESSIBLE": {
"message": [
"Value `<value>` cannot be accessed inside tasks."
]
},
"VALUE_NOT_ALLOWED" : {
"message" : [
"Value for `<arg_name>` has to be amongst the following values: <allowed_values>."
]
},
"VALUE_NOT_ANY_OR_ALL" : {
"message" : [
"Value for `<arg_name>` must be 'any' or 'all', got '<arg_value>'."
]
},
"VALUE_NOT_BETWEEN" : {
"message" : [
"Value for `<arg_name>` must be between <min> and <max>."
]
},
"VALUE_NOT_NON_EMPTY_STR" : {
"message" : [
"Value for `<arg_name>` must be a non empty string, got '<arg_value>'."
]
},
"VALUE_NOT_PEARSON" : {
"message" : [
"Value for `<arg_name>` only supports the 'pearson', got '<arg_value>'."
]
},
"VALUE_NOT_POSITIVE" : {
"message" : [
"Value for `<arg_name>` must be positive, got '<arg_value>'."
]
},
"VALUE_NOT_TRUE" : {
"message" : [
"Value for `<arg_name>` must be True, got '<arg_value>'."
]
},
"VALUE_OUT_OF_BOUND" : {
"message" : [
"Value for `<arg_name>` must be greater than <lower_bound> or less than <upper_bound>, got <actual>"
]
},
"WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION" : {
"message" : [
"Function `<func_name>` should take between 1 and 3 arguments, but provided function takes <num_args>."
]
},
"WRONG_NUM_COLUMNS" : {
"message" : [
"Function `<func_name>` should take at least <num_cols> columns."
]
},
"ZERO_INDEX": {
"message": [
"Index must be non-zero."
]
}
}
"""
ERROR_CLASSES_MAP = json.loads(ERROR_CLASSES_JSON)

View File

@@ -0,0 +1,16 @@
# # noqa: D104
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,168 @@
from typing import Optional, cast # noqa: D100
from ..utils import ErrorClassesReader
class PySparkException(Exception):
"""Base Exception for handling errors generated from PySpark."""
def __init__( # noqa: D107
self,
message: Optional[str] = None,
# The error class, decides the message format, must be one of the valid options listed in 'error_classes.py'
error_class: Optional[str] = None,
# The dictionary listing the arguments specified in the message (or the error_class)
message_parameters: Optional[dict[str, str]] = None,
) -> None:
# `message` vs `error_class` & `message_parameters` are mutually exclusive.
assert (message is not None and (error_class is None and message_parameters is None)) or (
message is None and (error_class is not None and message_parameters is not None)
)
self.error_reader = ErrorClassesReader()
if message is None:
self.message = self.error_reader.get_error_message(
cast("str", error_class), cast("dict[str, str]", message_parameters)
)
else:
self.message = message
self.error_class = error_class
self.message_parameters = message_parameters
def getErrorClass(self) -> Optional[str]:
"""Returns an error class as a string.
.. versionadded:: 3.4.0
See Also:
--------
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getSqlState`
"""
return self.error_class
def getMessageParameters(self) -> Optional[dict[str, str]]:
"""Returns a message parameters as a dictionary.
.. versionadded:: 3.4.0
See Also:
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getSqlState`
"""
return self.message_parameters
def getSqlState(self) -> None:
"""Returns an SQLSTATE as a string.
Errors generated in Python have no SQLSTATE, so it always returns None.
.. versionadded:: 3.4.0
See Also:
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getMessageParameters`
"""
return None
def __str__(self) -> str: # noqa: D105
if self.getErrorClass() is not None:
return f"[{self.getErrorClass()}] {self.message}"
else:
return self.message
class AnalysisException(PySparkException):
"""Failed to analyze a SQL query plan."""
class SessionNotSameException(PySparkException):
"""Performed the same operation on different SparkSession."""
class TempTableAlreadyExistsException(AnalysisException):
"""Failed to create temp view since it is already exists."""
class ParseException(AnalysisException):
"""Failed to parse a SQL command."""
class IllegalArgumentException(PySparkException):
"""Passed an illegal or inappropriate argument."""
class ArithmeticException(PySparkException):
"""Arithmetic exception thrown from Spark with an error class."""
class UnsupportedOperationException(PySparkException):
"""Unsupported operation exception thrown from Spark with an error class."""
class ArrayIndexOutOfBoundsException(PySparkException):
"""Array index out of bounds exception thrown from Spark with an error class."""
class DateTimeException(PySparkException):
"""Datetime exception thrown from Spark with an error class."""
class NumberFormatException(IllegalArgumentException):
"""Number format exception thrown from Spark with an error class."""
class StreamingQueryException(PySparkException):
"""Exception that stopped a :class:`StreamingQuery`."""
class QueryExecutionException(PySparkException):
"""Failed to execute a query."""
class PythonException(PySparkException):
"""Exceptions thrown from Python workers."""
class SparkRuntimeException(PySparkException):
"""Runtime exception thrown from Spark with an error class."""
class SparkUpgradeException(PySparkException):
"""Exception thrown because of Spark upgrade."""
class UnknownException(PySparkException):
"""None of the above exceptions."""
class PySparkValueError(PySparkException, ValueError):
"""Wrapper class for ValueError to support error classes."""
class PySparkIndexError(PySparkException, IndexError):
"""Wrapper class for IndexError to support error classes."""
class PySparkTypeError(PySparkException, TypeError):
"""Wrapper class for TypeError to support error classes."""
class PySparkAttributeError(PySparkException, AttributeError):
"""Wrapper class for AttributeError to support error classes."""
class PySparkRuntimeError(PySparkException, RuntimeError):
"""Wrapper class for RuntimeError to support error classes."""
class PySparkAssertionError(PySparkException, AssertionError):
"""Wrapper class for AssertionError to support error classes."""
class PySparkNotImplementedError(PySparkException, NotImplementedError):
"""Wrapper class for NotImplementedError to support error classes."""

View File

@@ -0,0 +1,111 @@
# # noqa: D100
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from .error_classes import ERROR_CLASSES_MAP
class ErrorClassesReader:
"""A reader to load error information from error_classes.py."""
def __init__(self) -> None: # noqa: D107
self.error_info_map = ERROR_CLASSES_MAP
def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str:
"""Returns the completed error message by applying message parameters to the message template."""
message_template = self.get_message_template(error_class)
# Verify message parameters.
message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template)
assert set(message_parameters_from_template) == set(message_parameters), (
f"Undefined error message parameter for error class: {error_class}. Parameters: {message_parameters}"
)
table = str.maketrans("<>", "{}")
return message_template.translate(table).format(**message_parameters)
def get_message_template(self, error_class: str) -> str:
"""Returns the message template for corresponding error class from error_classes.py.
For example,
when given `error_class` is "EXAMPLE_ERROR_CLASS",
and corresponding error class in error_classes.py looks like the below:
.. code-block:: python
"EXAMPLE_ERROR_CLASS" : {
"message" : [
"Problem <A> because of <B>."
]
}
In this case, this function returns:
"Problem <A> because of <B>."
For sub error class, when given `error_class` is "EXAMPLE_ERROR_CLASS.SUB_ERROR_CLASS",
and corresponding error class in error_classes.py looks like the below:
.. code-block:: python
"EXAMPLE_ERROR_CLASS" : {
"message" : [
"Problem <A> because of <B>."
],
"sub_class" : {
"SUB_ERROR_CLASS" : {
"message" : [
"Do <C> to fix the problem."
]
}
}
}
In this case, this function returns:
"Problem <A> because <B>. Do <C> to fix the problem."
"""
error_classes = error_class.split(".")
len_error_classes = len(error_classes)
assert len_error_classes in (1, 2)
# Generate message template for main error class.
main_error_class = error_classes[0]
if main_error_class in self.error_info_map:
main_error_class_info_map = self.error_info_map[main_error_class]
else:
msg = f"Cannot find main error class '{main_error_class}'"
raise ValueError(msg)
main_message_template = "\n".join(main_error_class_info_map["message"])
has_sub_class = len_error_classes == 2
if not has_sub_class:
message_template = main_message_template
else:
# Generate message template for sub error class if exists.
sub_error_class = error_classes[1]
main_error_class_subclass_info_map = main_error_class_info_map["sub_class"]
if sub_error_class in main_error_class_subclass_info_map:
sub_error_class_info_map = main_error_class_subclass_info_map[sub_error_class]
else:
msg = f"Cannot find sub error class '{sub_error_class}'"
raise ValueError(msg)
sub_message_template = "\n".join(sub_error_class_info_map["message"])
message_template = main_message_template + " " + sub_message_template
return message_template

View File

@@ -0,0 +1,18 @@
# ruff: noqa: D100
from typing import Optional
class ContributionsAcceptedError(NotImplementedError):
"""This method is not planned to be implemented, if you would like to implement this method
or show your interest in this method to other members of the community,
feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb.
""" # noqa: D205
def __init__(self, message: Optional[str] = None) -> None: # noqa: D107
doc = self.__class__.__doc__
if message:
doc = message + "\n" + doc
super().__init__(doc)
__all__ = ["ContributionsAcceptedError"]

View File

@@ -0,0 +1,7 @@
from .catalog import Catalog # noqa: D104
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameWriter
from .session import SparkSession
__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"]

View File

@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
try:
from typing import Literal, Protocol
except ImportError:
from typing_extensions import Literal, Protocol
import datetime
import decimal
from .._typing import PrimitiveType
from . import types
from .column import Column
ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
DecimalLiteral = decimal.Decimal
DateTimeLiteral = Union[datetime.datetime, datetime.date]
LiteralType = PrimitiveType
AtomicDataTypeOrString = Union[types.AtomicType, str]
DataTypeOrString = Union[types.DataType, str]
OptionalPrimitiveType = Optional[PrimitiveType]
AtomicValue = TypeVar(
"AtomicValue",
datetime.datetime,
datetime.date,
decimal.Decimal,
bool,
str,
int,
float,
)
RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row)
SQLBatchedUDFType = Literal[100]
class SupportsOpen(Protocol):
def open(self, partition_id: int, epoch_id: int) -> bool: ...
class SupportsProcess(Protocol):
def process(self, row: types.Row) -> None: ...
class SupportsClose(Protocol):
def close(self, error: Exception) -> None: ...
class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
evalType: int
deterministic: bool
@property
def returnType(self) -> types.DataType: ...
def __call__(self, *args: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> "UserDefinedFunctionLike": ...

View File

@@ -0,0 +1,79 @@
from typing import NamedTuple, Optional, Union # noqa: D100
from .session import SparkSession
class Database(NamedTuple): # noqa: D101
name: str
description: Optional[str]
locationUri: str
class Table(NamedTuple): # noqa: D101
name: str
database: Optional[str]
description: Optional[str]
tableType: str
isTemporary: bool
class Column(NamedTuple): # noqa: D101
name: str
description: Optional[str]
dataType: str
nullable: bool
isPartition: bool
isBucket: bool
class Function(NamedTuple): # noqa: D101
name: str
description: Optional[str]
className: str
isTemporary: bool
class Catalog: # noqa: D101
def __init__(self, session: SparkSession) -> None: # noqa: D107
self._session = session
def listDatabases(self) -> list[Database]: # noqa: D102
res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
def transform_to_database(x: list[str]) -> Database:
return Database(name=x[0], description=None, locationUri="")
databases = [transform_to_database(x) for x in res]
return databases
def listTables(self) -> list[Table]: # noqa: D102
res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()
def transform_to_table(x: list[str]) -> Table:
return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
tables = [transform_to_table(x) for x in res]
return tables
def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102
query = f"""
select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}'
"""
if dbName:
query += f" and database_name = '{dbName}'"
res = self._session.conn.sql(query).fetchall()
def transform_to_column(x: list[Union[str, bool]]) -> Column:
return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
columns = [transform_to_column(x) for x in res]
return columns
def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102
raise NotImplementedError
def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102
raise NotImplementedError
__all__ = ["Catalog", "Column", "Database", "Function", "Table"]

View File

@@ -0,0 +1,361 @@
from collections.abc import Iterable # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from ..exception import ContributionsAcceptedError
from .types import DataType
if TYPE_CHECKING:
from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType
from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression
from duckdb.sqltypes import DuckDBPyType
__all__ = ["Column"]
def _get_expr(x: Union["Column", str]) -> Expression:
return x.expr if isinstance(x, Column) else ConstantExpression(x)
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
def _(self: "Column") -> "Column":
njc = getattr(self.expr, name)()
return Column(njc)
_.__doc__ = doc
return _
def _unary_op(
name: str,
doc: str = "unary operator",
) -> Callable[["Column"], "Column"]:
"""Create a method for given unary operator."""
def _(self: "Column") -> "Column":
# Call the function identified by 'name' on the internal Expression object
expr = getattr(self.expr, name)()
return Column(expr)
_.__doc__ = doc
return _
def _bin_op(
name: str,
doc: str = "binary operator",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a method for given binary operator."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
jc = _get_expr(other)
njc = getattr(self.expr, name)(jc)
return Column(njc)
_.__doc__ = doc
return _
def _bin_func(
name: str,
doc: str = "binary function",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a function expression for the given binary function."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
other = _get_expr(other)
func = FunctionExpression(name, self.expr, other)
return Column(func)
_.__doc__ = doc
return _
class Column:
"""A column in a DataFrame.
:class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
df.colName
df["colName"]
# 2. Create from an expression
df.colName + 1
1 / df.colName
.. versionadded:: 1.3.0
"""
def __init__(self, expr: Expression) -> None: # noqa: D107
self.expr = expr
# arithmetic operators
def __neg__(self) -> "Column": # noqa: D105
return Column(-self.expr)
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op("__and__")
__or__ = _bin_op("__or__")
__invert__ = _func_op("__invert__")
__rand__ = _bin_op("__rand__")
__ror__ = _bin_op("__ror__")
__add__ = _bin_op("__add__")
__sub__ = _bin_op("__sub__")
__mul__ = _bin_op("__mul__")
__div__ = _bin_op("__div__")
__truediv__ = _bin_op("__truediv__")
__mod__ = _bin_op("__mod__")
__pow__ = _bin_op("__pow__")
__radd__ = _bin_op("__radd__")
__rsub__ = _bin_op("__rsub__")
__rmul__ = _bin_op("__rmul__")
__rdiv__ = _bin_op("__rdiv__")
__rtruediv__ = _bin_op("__rtruediv__")
__rmod__ = _bin_op("__rmod__")
__rpow__ = _bin_op("__rpow__")
def __getitem__(self, k: Any) -> "Column": # noqa: ANN401
"""An expression that gets an item at position ``ordinal`` out of a list,
or gets an item by key out of a dict.
.. versionadded:: 1.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
k
a literal value, or a slice object without step.
Returns:
-------
:class:`Column`
Column representing the item got by key out of a dict, or substrings sliced by
the given slice object.
Examples:
--------
>>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
>>> df.select(df.l[slice(1, 3)], df.d["key"]).show()
+------------------+------+
|substring(l, 1, 3)|d[key]|
+------------------+------+
| abc| value|
+------------------+------+
""" # noqa: D205
if isinstance(k, slice):
raise ContributionsAcceptedError
# if k.step is not None:
# raise ValueError("Using a slice with a step value is not supported")
# return self.substr(k.start, k.stop)
else:
# TODO: this is super hacky # noqa: TD002, TD003
expr_str = str(self.expr) + "." + str(k)
return Column(ColumnExpression(expr_str))
def __getattr__(self, item: Any) -> "Column": # noqa: ANN401
"""An expression that gets an item at position ``ordinal`` out of a list,
or gets an item by key out of a dict.
Parameters
----------
item
a literal value.
Returns:
-------
:class:`Column`
Column representing the item got by key out of a dict.
Examples:
--------
>>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
>>> df.select(df.d.key).show()
+------+
|d[key]|
+------+
| value|
+------+
""" # noqa: D205
if item.startswith("__"):
msg = "Can not access __ (dunder) method"
raise AttributeError(msg)
return self[item]
def alias(self, alias: str) -> "Column": # noqa: D102
return Column(self.expr.alias(alias))
def when(self, condition: "Column", value: Union["Column", str]) -> "Column": # noqa: D102
if not isinstance(condition, Column):
msg = "condition should be a Column"
raise TypeError(msg)
v = _get_expr(value)
expr = self.expr.when(condition.expr, v)
return Column(expr)
def otherwise(self, value: Union["Column", str]) -> "Column": # noqa: D102
v = _get_expr(value)
expr = self.expr.otherwise(v)
return Column(expr)
def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102
internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type
return Column(self.expr.cast(internal_type))
def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column": # noqa: D102
if len(cols) == 1 and isinstance(cols[0], (list, set)):
# Only one argument supplied, it's a list
cols = cast("tuple", cols[0])
cols = cast(
"tuple",
[_get_expr(c) for c in cols],
)
return Column(self.expr.isin(*cols))
# logistic operators
def __eq__( # type: ignore[override]
self,
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
"""Binary function."""
return Column(self.expr == (_get_expr(other)))
def __ne__( # type: ignore[override]
self,
other: object,
) -> "Column":
"""Binary function."""
return Column(self.expr != (_get_expr(other)))
__lt__ = _bin_op("__lt__")
__le__ = _bin_op("__le__")
__ge__ = _bin_op("__ge__")
__gt__ = _bin_op("__gt__")
# String interrogation methods
contains = _bin_func("contains")
rlike = _bin_func("regexp_matches")
like = _bin_func("~~")
ilike = _bin_func("~~*")
startswith = _bin_func("starts_with")
endswith = _bin_func("suffix")
# order
_asc_doc = """
Returns a sort expression based on the ascending order of the column.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc()).collect()
[Row(name='Alice'), Row(name='Tom')]
"""
_asc_nulls_first_doc = """
Returns a sort expression based on ascending order of the column, and null values
return before non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
[Row(name=None), Row(name='Alice'), Row(name='Tom')]
"""
_asc_nulls_last_doc = """
Returns a sort expression based on ascending order of the column, and null values
appear after non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
[Row(name='Alice'), Row(name='Tom'), Row(name=None)]
"""
_desc_doc = """
Returns a sort expression based on the descending order of the column.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc()).collect()
[Row(name='Tom'), Row(name='Alice')]
"""
_desc_nulls_first_doc = """
Returns a sort expression based on the descending order of the column, and null values
appear before non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
[Row(name=None), Row(name='Tom'), Row(name='Alice')]
"""
_desc_nulls_last_doc = """
Returns a sort expression based on the descending order of the column, and null values
appear after non-null values.
Examples
--------
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
[Row(name='Tom'), Row(name='Alice'), Row(name=None)]
"""
asc = _unary_op("asc", _asc_doc)
desc = _unary_op("desc", _desc_doc)
nulls_first = _unary_op("nulls_first")
nulls_last = _unary_op("nulls_last")
def asc_nulls_first(self) -> "Column": # noqa: D102
return self.asc().nulls_first()
def asc_nulls_last(self) -> "Column": # noqa: D102
return self.asc().nulls_last()
def desc_nulls_first(self) -> "Column": # noqa: D102
return self.desc().nulls_first()
def desc_nulls_last(self) -> "Column": # noqa: D102
return self.desc().nulls_last()
def isNull(self) -> "Column": # noqa: D102
return Column(self.expr.isnull())
def isNotNull(self) -> "Column": # noqa: D102
return Column(self.expr.isnotnull())

View File

@@ -0,0 +1,24 @@
from typing import Optional, Union # noqa: D100
from duckdb import DuckDBPyConnection
from duckdb.experimental.spark._globals import _NoValue, _NoValueType
class RuntimeConfig: # noqa: D101
def __init__(self, connection: DuckDBPyConnection) -> None: # noqa: D107
self._connection = connection
def set(self, key: str, value: str) -> None: # noqa: D102
raise NotImplementedError
def isModifiable(self, key: str) -> bool: # noqa: D102
raise NotImplementedError
def unset(self, key: str) -> None: # noqa: D102
raise NotImplementedError
def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102
raise NotImplementedError
__all__ = ["RuntimeConfig"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,424 @@
# # noqa: D100
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Callable, Union, overload
from ..exception import ContributionsAcceptedError
from .column import Column
from .dataframe import DataFrame
from .functions import _to_column_expr
from .types import NumericType
# Only import symbols needed for type checking if something is type checking
if TYPE_CHECKING:
from ._typing import ColumnOrName
from .session import SparkSession
__all__ = ["GroupedData", "Grouping"]
def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame:
expressions = ",".join(list(cols))
group_by = str(self._grouping) if self._grouping else ""
projections = self._grouping.get_columns()
jdf = self._df.relation.apply(
function_name=name, # aggregate function
function_aggr=expressions, # inputs to aggregate
group_expr=group_by, # groups
projected_columns=projections, # projections
)
return DataFrame(jdf, self.session)
def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
def _api(self: "GroupedData", *cols: str) -> DataFrame:
name = f.__name__
return _api_internal(self, name, *cols)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
class Grouping: # noqa: D101
def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: # noqa: D107
self._type = ""
self._cols = [_to_column_expr(x) for x in cols]
if "special" in kwargs:
special = kwargs["special"]
accepted_special = ["cube", "rollup"]
assert special in accepted_special
self._type = special
def get_columns(self) -> str: # noqa: D102
columns = ",".join([str(x) for x in self._cols])
return columns
def __str__(self) -> str: # noqa: D105
columns = self.get_columns()
if self._type:
return self._type + "(" + columns + ")"
return columns
class GroupedData:
"""A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
""" # noqa: D205
def __init__(self, grouping: Grouping, df: DataFrame) -> None: # noqa: D107
self._grouping = grouping
self._df = df
self.session: SparkSession = df.session
def __repr__(self) -> str: # noqa: D105
return str(self._df)
def count(self) -> DataFrame:
"""Counts the number of records for each group.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
... )
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
Group-by name, and count each group.
>>> df.groupBy(df.name).count().sort("name").show()
+-----+-----+
| name|count|
+-----+-----+
|Alice| 2|
| Bob| 2|
+-----+-----+
"""
return _api_internal(self, "count").withColumnRenamed("count_star()", "count")
@df_varargs_api
def mean(self, *cols: str) -> DataFrame:
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
"""
def avg(self, *cols: str) -> DataFrame:
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the mean of the age in each group.
>>> df.groupBy("name").avg("age").sort("name").show()
+-----+--------+
| name|avg(age)|
+-----+--------+
|Alice| 2.5|
| Bob| 7.5|
+-----+--------+
Calculate the mean of the age and height in all data.
>>> df.groupBy().avg("age", "height").show()
+--------+-----------+
|avg(age)|avg(height)|
+--------+-----------+
| 5.0| 110.0|
+--------+-----------+
"""
columns = list(cols)
if len(columns) == 0:
schema = self._df.schema
# Take only the numeric types of the relation
columns: list[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)]
return _api_internal(self, "avg", *columns)
@df_varargs_api
def max(self, *cols: str) -> DataFrame:
"""Computes the max value for each numeric columns for each group.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the max of the age in each group.
>>> df.groupBy("name").max("age").sort("name").show()
+-----+--------+
| name|max(age)|
+-----+--------+
|Alice| 3|
| Bob| 10|
+-----+--------+
Calculate the max of the age and height in all data.
>>> df.groupBy().max("age", "height").show()
+--------+-----------+
|max(age)|max(height)|
+--------+-----------+
| 10| 140|
+--------+-----------+
"""
@df_varargs_api
def min(self, *cols: str) -> DataFrame:
"""Computes the min value for each numeric column for each group.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the min of the age in each group.
>>> df.groupBy("name").min("age").sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
Calculate the min of the age and height in all data.
>>> df.groupBy().min("age", "height").show()
+--------+-----------+
|min(age)|min(height)|
+--------+-----------+
| 2| 80|
+--------+-----------+
"""
@df_varargs_api
def sum(self, *cols: str) -> DataFrame:
"""Computes the sum for each numeric columns for each group.
Parameters
----------
cols : str
column names. Non-numeric columns are ignored.
Examples:
--------
>>> df = spark.createDataFrame(
... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
... ["age", "name", "height"],
... )
>>> df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
Group-by name, and calculate the sum of the age in each group.
>>> df.groupBy("name").sum("age").sort("name").show()
+-----+--------+
| name|sum(age)|
+-----+--------+
|Alice| 5|
| Bob| 15|
+-----+--------+
Calculate the sum of the age and height in all data.
>>> df.groupBy().sum("age", "height").show()
+--------+-----------+
|sum(age)|sum(height)|
+--------+-----------+
| 20| 440|
+--------+-----------+
"""
@overload
def agg(self, *exprs: Column) -> DataFrame: ...
@overload
def agg(self, __exprs: dict[str, str]) -> DataFrame: ... # noqa: PYI063
def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame:
"""Compute aggregates and returns the result as a :class:`DataFrame`.
The available aggregate functions can be:
1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`
2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`
.. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
a full shuffle is required. Also, all the data of a group will be loaded into
memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
.. seealso:: :func:`pyspark.sql.functions.pandas_udf`
If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
is the column to perform aggregation on, and the value is the aggregate function.
Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
.. versionadded:: 1.3.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
exprs : dict
a dict mapping from column name (string) to aggregate functions (string),
or a list of :class:`Column`.
Notes:
-----
Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
in a single call to this function.
Examples:
--------
>>> from pyspark.sql import functions as F
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
... )
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
Group-by name, and count each group.
>>> df.groupBy(df.name)
GroupedData[grouping...: [name...], value: [age: bigint, name: string], type: GroupBy]
>>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
+-----+--------+
| name|count(1)|
+-----+--------+
|Alice| 2|
| Bob| 2|
+-----+--------+
Group-by name, and calculate the minimum age.
>>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
Same as above but uses pandas UDF.
>>> @pandas_udf("int", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def min_udf(v):
... return v.min()
>>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP
+-----+------------+
| name|min_udf(age)|
+-----+------------+
|Alice| 2|
| Bob| 5|
+-----+------------+
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
raise ContributionsAcceptedError
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
expressions = list(self._grouping._cols)
expressions.extend([x.expr for x in exprs])
group_by = str(self._grouping)
rel = self._df.relation.select(*expressions, groups=group_by)
return DataFrame(rel, self.session)
# TODO: add 'pivot' # noqa: TD002, TD003

View File

@@ -0,0 +1,435 @@
from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100
from ..errors import PySparkNotImplementedError, PySparkTypeError
from ..exception import ContributionsAcceptedError
from .types import StructType
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
if TYPE_CHECKING:
from duckdb.experimental.spark.sql.dataframe import DataFrame
from duckdb.experimental.spark.sql.session import SparkSession
class DataFrameWriter: # noqa: D101
def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107
self.dataframe = dataframe
def saveAsTable(self, table_name: str) -> None: # noqa: D102
relation = self.dataframe.relation
relation.create(table_name)
def parquet( # noqa: D102
self,
path: str,
mode: Optional[str] = None,
partitionBy: Union[str, list[str], None] = None,
compression: Optional[str] = None,
) -> None:
relation = self.dataframe.relation
if mode:
raise NotImplementedError
if partitionBy:
raise NotImplementedError
relation.write_parquet(path, compression=compression)
def csv( # noqa: D102
self,
path: str,
mode: Optional[str] = None,
compression: Optional[str] = None,
sep: Optional[str] = None,
quote: Optional[str] = None,
escape: Optional[str] = None,
header: Optional[Union[bool, str]] = None,
nullValue: Optional[str] = None,
escapeQuotes: Optional[Union[bool, str]] = None,
quoteAll: Optional[Union[bool, str]] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
charToEscapeQuoteEscaping: Optional[str] = None,
encoding: Optional[str] = None,
emptyValue: Optional[str] = None,
lineSep: Optional[str] = None,
) -> None:
if mode not in (None, "overwrite"):
raise NotImplementedError
if escapeQuotes:
raise NotImplementedError
if ignoreLeadingWhiteSpace:
raise NotImplementedError
if ignoreTrailingWhiteSpace:
raise NotImplementedError
if charToEscapeQuoteEscaping:
raise NotImplementedError
if emptyValue:
raise NotImplementedError
if lineSep:
raise NotImplementedError
relation = self.dataframe.relation
relation.write_csv(
path,
sep=sep,
na_rep=nullValue,
quotechar=quote,
compression=compression,
escapechar=escape,
header=header if isinstance(header, bool) else header == "True",
encoding=encoding,
quoting=quoteAll,
date_format=dateFormat,
timestamp_format=timestampFormat,
)
class DataFrameReader: # noqa: D101
def __init__(self, session: "SparkSession") -> None: # noqa: D107
self.session = session
def load( # noqa: D102
self,
path: Optional[Union[str, list[str]]] = None,
format: Optional[str] = None,
schema: Optional[Union[StructType, str]] = None,
**options: OptionalPrimitiveType,
) -> "DataFrame":
from duckdb.experimental.spark.sql.dataframe import DataFrame
if not isinstance(path, str):
raise TypeError
if options:
raise ContributionsAcceptedError
rel = None
if format:
format = format.lower()
if format == "csv" or format == "tsv":
rel = self.session.conn.read_csv(path)
elif format == "json":
rel = self.session.conn.read_json(path)
elif format == "parquet":
rel = self.session.conn.read_parquet(path)
else:
raise ContributionsAcceptedError
else:
rel = self.session.conn.sql(f"select * from {path}")
df = DataFrame(rel, self.session)
if schema:
if not isinstance(schema, StructType):
raise ContributionsAcceptedError
schema = cast("StructType", schema)
types, names = schema.extract_types_and_names()
df = df._cast_types(types)
df = df.toDF(names)
return df
def csv( # noqa: D102
self,
path: Union[str, list[str]],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
quote: Optional[str] = None,
escape: Optional[str] = None,
comment: Optional[str] = None,
header: Optional[Union[bool, str]] = None,
inferSchema: Optional[Union[bool, str]] = None,
ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
nullValue: Optional[str] = None,
nanValue: Optional[str] = None,
positiveInf: Optional[str] = None,
negativeInf: Optional[str] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
maxColumns: Optional[Union[int, str]] = None,
maxCharsPerColumn: Optional[Union[int, str]] = None,
maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
mode: Optional[str] = None,
columnNameOfCorruptRecord: Optional[str] = None,
multiLine: Optional[Union[bool, str]] = None,
charToEscapeQuoteEscaping: Optional[str] = None,
samplingRatio: Optional[Union[float, str]] = None,
enforceSchema: Optional[Union[bool, str]] = None,
emptyValue: Optional[str] = None,
locale: Optional[str] = None,
lineSep: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
modifiedBefore: Optional[Union[bool, str]] = None,
modifiedAfter: Optional[Union[bool, str]] = None,
unescapedQuoteHandling: Optional[str] = None,
) -> "DataFrame":
if not isinstance(path, str):
raise NotImplementedError
if schema and not isinstance(schema, StructType):
raise ContributionsAcceptedError
if comment:
raise ContributionsAcceptedError
if inferSchema:
raise ContributionsAcceptedError
if ignoreLeadingWhiteSpace:
raise ContributionsAcceptedError
if ignoreTrailingWhiteSpace:
raise ContributionsAcceptedError
if nanValue:
raise ConnectionAbortedError
if positiveInf:
raise ConnectionAbortedError
if negativeInf:
raise ConnectionAbortedError
if negativeInf:
raise ConnectionAbortedError
if maxColumns:
raise ContributionsAcceptedError
if maxCharsPerColumn:
raise ContributionsAcceptedError
if maxMalformedLogPerPartition:
raise ContributionsAcceptedError
if mode:
raise ContributionsAcceptedError
if columnNameOfCorruptRecord:
raise ContributionsAcceptedError
if multiLine:
raise ContributionsAcceptedError
if charToEscapeQuoteEscaping:
raise ContributionsAcceptedError
if samplingRatio:
raise ContributionsAcceptedError
if enforceSchema:
raise ContributionsAcceptedError
if emptyValue:
raise ContributionsAcceptedError
if locale:
raise ContributionsAcceptedError
if pathGlobFilter:
raise ContributionsAcceptedError
if recursiveFileLookup:
raise ContributionsAcceptedError
if modifiedBefore:
raise ContributionsAcceptedError
if modifiedAfter:
raise ContributionsAcceptedError
if unescapedQuoteHandling:
raise ContributionsAcceptedError
if lineSep:
# We have support for custom newline, just needs to be ported to 'read_csv'
raise NotImplementedError
dtype = None
names = None
if schema:
schema = cast("StructType", schema)
dtype, names = schema.extract_types_and_names()
rel = self.session.conn.read_csv(
path,
header=header if isinstance(header, bool) else header == "True",
sep=sep,
dtype=dtype,
na_values=nullValue,
quotechar=quote,
escapechar=escape,
encoding=encoding,
date_format=dateFormat,
timestamp_format=timestampFormat,
)
from ..sql.dataframe import DataFrame
df = DataFrame(rel, self.session)
if names:
df = df.toDF(*names)
return df
def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": # noqa: D102
input = list(paths)
if len(input) != 1:
msg = "Only single paths are supported for now"
raise NotImplementedError(msg)
option_amount = len(options.keys())
if option_amount != 0:
msg = "Options are not supported"
raise ContributionsAcceptedError(msg)
path = input[0]
rel = self.session.conn.read_parquet(path)
from ..sql.dataframe import DataFrame
df = DataFrame(rel, self.session)
return df
def json(
self,
path: Union[str, list[str]],
schema: Optional[Union[StructType, str]] = None,
primitivesAsString: Optional[Union[bool, str]] = None,
prefersDecimal: Optional[Union[bool, str]] = None,
allowComments: Optional[Union[bool, str]] = None,
allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
allowSingleQuotes: Optional[Union[bool, str]] = None,
allowNumericLeadingZero: Optional[Union[bool, str]] = None,
allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
mode: Optional[str] = None,
columnNameOfCorruptRecord: Optional[str] = None,
dateFormat: Optional[str] = None,
timestampFormat: Optional[str] = None,
multiLine: Optional[Union[bool, str]] = None,
allowUnquotedControlChars: Optional[Union[bool, str]] = None,
lineSep: Optional[str] = None,
samplingRatio: Optional[Union[float, str]] = None,
dropFieldIfAllNull: Optional[Union[bool, str]] = None,
encoding: Optional[str] = None,
locale: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
recursiveFileLookup: Optional[Union[bool, str]] = None,
modifiedBefore: Optional[Union[bool, str]] = None,
modifiedAfter: Optional[Union[bool, str]] = None,
allowNonNumericNumbers: Optional[Union[bool, str]] = None,
) -> "DataFrame":
"""Loads JSON files and returns the results as a :class:`DataFrame`.
`JSON Lines <http://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
For JSON (one record per file), set the ``multiLine`` parameter to ``true``.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
.. versionadded:: 1.4.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
path : str, list or :class:`RDD`
string represents path to the JSON dataset, or a list of paths,
or RDD of Strings storing JSON objects.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema or
a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
Other Parameters
----------------
Extra options
For the extra options, refer to
`Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
for the version you use.
.. # noqa
Examples:
--------
Write a DataFrame into a JSON file and read it back.
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... # Write a DataFrame into a JSON file
... spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
... "overwrite"
... ).format("json").save(d)
...
... # Read the JSON file as a DataFrame.
... spark.read.json(d).show()
+---+------------+
|age| name|
+---+------------+
|100|Hyukjin Kwon|
+---+------------+
"""
if schema is not None:
msg = "The 'schema' option is not supported"
raise ContributionsAcceptedError(msg)
if primitivesAsString is not None:
msg = "The 'primitivesAsString' option is not supported"
raise ContributionsAcceptedError(msg)
if prefersDecimal is not None:
msg = "The 'prefersDecimal' option is not supported"
raise ContributionsAcceptedError(msg)
if allowComments is not None:
msg = "The 'allowComments' option is not supported"
raise ContributionsAcceptedError(msg)
if allowUnquotedFieldNames is not None:
msg = "The 'allowUnquotedFieldNames' option is not supported"
raise ContributionsAcceptedError(msg)
if allowSingleQuotes is not None:
msg = "The 'allowSingleQuotes' option is not supported"
raise ContributionsAcceptedError(msg)
if allowNumericLeadingZero is not None:
msg = "The 'allowNumericLeadingZero' option is not supported"
raise ContributionsAcceptedError(msg)
if allowBackslashEscapingAnyCharacter is not None:
msg = "The 'allowBackslashEscapingAnyCharacter' option is not supported"
raise ContributionsAcceptedError(msg)
if mode is not None:
msg = "The 'mode' option is not supported"
raise ContributionsAcceptedError(msg)
if columnNameOfCorruptRecord is not None:
msg = "The 'columnNameOfCorruptRecord' option is not supported"
raise ContributionsAcceptedError(msg)
if dateFormat is not None:
msg = "The 'dateFormat' option is not supported"
raise ContributionsAcceptedError(msg)
if timestampFormat is not None:
msg = "The 'timestampFormat' option is not supported"
raise ContributionsAcceptedError(msg)
if multiLine is not None:
msg = "The 'multiLine' option is not supported"
raise ContributionsAcceptedError(msg)
if allowUnquotedControlChars is not None:
msg = "The 'allowUnquotedControlChars' option is not supported"
raise ContributionsAcceptedError(msg)
if lineSep is not None:
msg = "The 'lineSep' option is not supported"
raise ContributionsAcceptedError(msg)
if samplingRatio is not None:
msg = "The 'samplingRatio' option is not supported"
raise ContributionsAcceptedError(msg)
if dropFieldIfAllNull is not None:
msg = "The 'dropFieldIfAllNull' option is not supported"
raise ContributionsAcceptedError(msg)
if encoding is not None:
msg = "The 'encoding' option is not supported"
raise ContributionsAcceptedError(msg)
if locale is not None:
msg = "The 'locale' option is not supported"
raise ContributionsAcceptedError(msg)
if pathGlobFilter is not None:
msg = "The 'pathGlobFilter' option is not supported"
raise ContributionsAcceptedError(msg)
if recursiveFileLookup is not None:
msg = "The 'recursiveFileLookup' option is not supported"
raise ContributionsAcceptedError(msg)
if modifiedBefore is not None:
msg = "The 'modifiedBefore' option is not supported"
raise ContributionsAcceptedError(msg)
if modifiedAfter is not None:
msg = "The 'modifiedAfter' option is not supported"
raise ContributionsAcceptedError(msg)
if allowNonNumericNumbers is not None:
msg = "The 'allowNonNumericNumbers' option is not supported"
raise ContributionsAcceptedError(msg)
if isinstance(path, str):
path = [path]
if isinstance(path, list):
if len(path) == 1:
rel = self.session.conn.read_json(path[0])
from .dataframe import DataFrame
df = DataFrame(rel, self.session)
return df
raise PySparkNotImplementedError(message="Only a single path is supported for now")
else:
raise PySparkTypeError(
error_class="NOT_STR_OR_LIST_OF_RDD",
message_parameters={
"arg_name": "path",
"arg_type": type(path).__name__,
},
)
__all__ = ["DataFrameReader", "DataFrameWriter"]

View File

@@ -0,0 +1,297 @@
import uuid # noqa: D100
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union
import duckdb
if TYPE_CHECKING:
from pandas.core.frame import DataFrame as PandasDataFrame
from .catalog import Catalog
from ..conf import SparkConf
from ..context import SparkContext
from ..errors import PySparkTypeError
from ..exception import ContributionsAcceptedError
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameReader
from .streaming import DataStreamReader
from .types import StructType
from .udf import UDFRegistration
# In spark:
# SparkSession holds a SparkContext
# SparkContext gets created from SparkConf
# At this level the check is made to determine whether the instance already exists and just needs
# to be retrieved or it needs to be created.
# For us this is done inside of `duckdb.connect`, based on the passed in path + configuration
# SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class
# data is a List of rows
# every value in each row needs to be turned into a Value
def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[duckdb.Value]:
from duckdb import Value
new_data = []
for row in data:
new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema])]
new_data.append(new_row)
return new_data
class SparkSession: # noqa: D101
def __init__(self, context: SparkContext) -> None: # noqa: D107
self.conn = context.connection
self._context = context
self._conf = RuntimeConfig(self.conn)
def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame:
try:
import pandas
has_pandas = True
except ImportError:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
unique_name = f"pyspark_pandas_df_{uuid.uuid1()}"
self.conn.register(unique_name, data)
return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)
def verify_tuple_integrity(tuples: list[tuple]) -> None:
if len(tuples) <= 1:
return
expected_length = len(tuples[0])
for i, item in enumerate(tuples[1:]):
actual_length = len(item)
if expected_length == actual_length:
continue
raise PySparkTypeError(
error_class="LENGTH_SHOULD_BE_THE_SAME",
message_parameters={
"arg1": f"data{i}",
"arg2": f"data{i + 1}",
"arg1_length": str(expected_length),
"arg2_length": str(actual_length),
},
)
if not isinstance(data, list):
data = list(data)
verify_tuple_integrity(data)
def construct_query(tuples: Iterable) -> str:
def construct_values_list(row: Sized, start_param_idx: int) -> str:
parameter_count = len(row)
parameters = [f"${x + start_param_idx}" for x in range(parameter_count)]
parameters = "(" + ", ".join(parameters) + ")"
return parameters
row_size = len(tuples[0])
values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)]
values_list = ", ".join(values_list)
query = f"""
select * from (values {values_list})
"""
return query
query = construct_query(data)
def construct_parameters(tuples: Iterable) -> list[list]:
parameters = []
for row in tuples:
parameters.extend(list(row))
return parameters
parameters = construct_parameters(data)
rel = self.conn.sql(query, params=parameters)
return DataFrame(rel, self)
def _createDataFrameFromPandas(
self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None]
) -> DataFrame:
df = self._create_dataframe(data)
# Cast to types
if types:
df = df._cast_types(*types)
# Alias to names
if names:
df = df.toDF(*names)
return df
def createDataFrame( # noqa: D102
self,
data: Union["PandasDataFrame", Iterable[Any]],
schema: Optional[Union[StructType, list[str]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
) -> DataFrame:
if samplingRatio:
raise NotImplementedError
if not verifySchema:
raise NotImplementedError
types = None
names = None
if isinstance(data, DataFrame):
raise PySparkTypeError(
error_class="SHOULD_NOT_DATAFRAME",
message_parameters={"arg_name": "data"},
)
if schema:
if isinstance(schema, StructType):
types, names = schema.extract_types_and_names()
else:
names = schema
try:
import pandas
has_pandas = True
except ImportError:
has_pandas = False
# Falsey check on pandas dataframe is not defined, so first check if it's not a pandas dataframe
# Then check if 'data' is None or []
if has_pandas and isinstance(data, pandas.DataFrame):
return self._createDataFrameFromPandas(data, types, names)
# Finally check if a schema was provided
is_empty = False
if not data and names:
# Create NULLs for every type in our dataframe
is_empty = True
data = [tuple(None for _ in names)]
if schema and isinstance(schema, StructType):
# Transform the data into Values to combine the data+schema
data = _combine_data_and_schema(data, schema)
df = self._create_dataframe(data)
if is_empty:
rel = df.relation
# Add impossible where clause
rel = rel.filter("1=0")
df = DataFrame(rel, self)
# Cast to types
if types:
df = df._cast_types(*types)
# Alias to names
if names:
df = df.toDF(*names)
return df
def newSession(self) -> "SparkSession": # noqa: D102
return SparkSession(self._context)
def range( # noqa: D102
self,
start: int,
end: Optional[int] = None,
step: int = 1,
numPartitions: Optional[int] = None,
) -> "DataFrame":
if numPartitions:
raise ContributionsAcceptedError
if end is None:
end = start
start = 0
return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self)
def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102, ANN401
if kwargs:
raise NotImplementedError
relation = self.conn.sql(sqlQuery)
return DataFrame(relation, self)
def stop(self) -> None: # noqa: D102
self._context.stop()
def table(self, tableName: str) -> DataFrame: # noqa: D102
relation = self.conn.table(tableName)
return DataFrame(relation, self)
def getActiveSession(self) -> "SparkSession": # noqa: D102
return self
@property
def catalog(self) -> "Catalog": # noqa: D102
if not hasattr(self, "_catalog"):
from duckdb.experimental.spark.sql.catalog import Catalog
self._catalog = Catalog(self)
return self._catalog
@property
def conf(self) -> RuntimeConfig: # noqa: D102
return self._conf
@property
def read(self) -> DataFrameReader: # noqa: D102
return DataFrameReader(self)
@property
def readStream(self) -> DataStreamReader: # noqa: D102
return DataStreamReader(self)
@property
def sparkContext(self) -> SparkContext: # noqa: D102
return self._context
@property
def streams(self) -> NoReturn: # noqa: D102
raise ContributionsAcceptedError
@property
def udf(self) -> UDFRegistration: # noqa: D102
return UDFRegistration(self)
@property
def version(self) -> str: # noqa: D102
return "1.0.0"
class Builder: # noqa: D106
def __init__(self) -> None: # noqa: D107
pass
def master(self, name: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def appName(self, name: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def remote(self, url: str) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
def getOrCreate(self) -> "SparkSession": # noqa: D102
context = SparkContext("__ignored__")
return SparkSession(context)
def config( # noqa: D102
self,
key: Optional[str] = None,
value: Optional[Any] = None, # noqa: ANN401
conf: Optional[SparkConf] = None,
) -> "SparkSession.Builder":
return self
def enableHiveSupport(self) -> "SparkSession.Builder": # noqa: D102
# no-op
return self
builder = Builder()
__all__ = ["SparkSession"]

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Optional, Union # noqa: D100
from .types import StructType
if TYPE_CHECKING:
from .dataframe import DataFrame
from .session import SparkSession
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
class DataStreamWriter: # noqa: D101
def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107
self.dataframe = dataframe
def toTable(self, table_name: str) -> None: # noqa: D102
# Should we register the dataframe or create a table from the contents?
raise NotImplementedError
class DataStreamReader: # noqa: D101
def __init__(self, session: "SparkSession") -> None: # noqa: D107
self.session = session
def load( # noqa: D102
self,
path: Optional[str] = None,
format: Optional[str] = None,
schema: Union[StructType, str, None] = None,
**options: OptionalPrimitiveType,
) -> "DataFrame":
raise NotImplementedError
__all__ = ["DataStreamReader", "DataStreamWriter"]

View File

@@ -0,0 +1,113 @@
from typing import cast # noqa: D100
from duckdb.sqltypes import DuckDBPyType
from ..exception import ContributionsAcceptedError
from .types import (
ArrayType,
BinaryType,
BitstringType,
BooleanType,
ByteType,
DataType,
DateType,
DayTimeIntervalType,
DecimalType,
DoubleType,
FloatType,
HugeIntegerType,
IntegerType,
LongType,
MapType,
ShortType,
StringType,
StructField,
StructType,
TimeNTZType,
TimestampMilisecondNTZType,
TimestampNanosecondNTZType,
TimestampNTZType,
TimestampSecondNTZType,
TimestampType,
TimeType,
UnsignedByteType,
UnsignedHugeIntegerType,
UnsignedIntegerType,
UnsignedLongType,
UnsignedShortType,
UUIDType,
)
_sqltype_to_spark_class = {
"boolean": BooleanType,
"utinyint": UnsignedByteType,
"tinyint": ByteType,
"usmallint": UnsignedShortType,
"smallint": ShortType,
"uinteger": UnsignedIntegerType,
"integer": IntegerType,
"ubigint": UnsignedLongType,
"bigint": LongType,
"hugeint": HugeIntegerType,
"uhugeint": UnsignedHugeIntegerType,
"varchar": StringType,
"blob": BinaryType,
"bit": BitstringType,
"uuid": UUIDType,
"date": DateType,
"time": TimeNTZType,
"time with time zone": TimeType,
"timestamp": TimestampNTZType,
"timestamp with time zone": TimestampType,
"timestamp_ms": TimestampNanosecondNTZType,
"timestamp_ns": TimestampMilisecondNTZType,
"timestamp_s": TimestampSecondNTZType,
"interval": DayTimeIntervalType,
"list": ArrayType,
"struct": StructType,
"map": MapType,
# union
# enum
# null (???)
"float": FloatType,
"double": DoubleType,
"decimal": DecimalType,
}
def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
id = dtype.id
if id == "list" or id == "array":
children = dtype.children
return ArrayType(convert_type(children[0][1]))
if id == "union":
msg = (
"Union types are not supported in the PySpark interface. "
"DuckDB union types cannot be directly mapped to PySpark types."
)
raise ContributionsAcceptedError(msg)
if id == "struct":
children: list[tuple[str, DuckDBPyType]] = dtype.children
fields = [StructField(x[0], convert_type(x[1])) for x in children]
return StructType(fields)
if id == "map":
return MapType(convert_type(dtype.key), convert_type(dtype.value))
raise NotImplementedError
def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
id = dtype.id
if id in ["list", "struct", "map", "array"]:
return convert_nested_type(dtype)
if id == "decimal":
children: list[tuple[str, DuckDBPyType]] = dtype.children
precision = cast("int", children[0][1])
scale = cast("int", children[1][1])
return DecimalType(precision, scale)
spark_type = _sqltype_to_spark_class[id]
return spark_type()
def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103
fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])]
return StructType(fields)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from .types import DataType
if TYPE_CHECKING:
from .session import SparkSession
DataTypeOrString = Union[DataType, str]
UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike")
class UDFRegistration: # noqa: D101
def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107
self.sparkSession = sparkSession
def register( # noqa: D102
self,
name: str,
f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
returnType: Optional["DataTypeOrString"] = None,
) -> "UserDefinedFunctionLike":
self.sparkSession.conn.create_function(name, f, return_type=returnType)
def registerJavaFunction( # noqa: D102
self,
name: str,
javaClassName: str,
returnType: Optional["DataTypeOrString"] = None,
) -> None:
raise NotImplementedError
def registerJavaUDAF(self, name: str, javaClassName: str) -> None: # noqa: D102
raise NotImplementedError
__all__ = ["UDFRegistration"]