initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,381 @@
# ruff: noqa: F401
"""The DuckDB Python Package.
This module re-exports the DuckDB C++ extension (`_duckdb`) and provides DuckDB's public API.
Note:
- Some symbols exposed here are implementation details of DuckDB's C++ engine.
- They are kept for backwards compatibility but are not considered stable API.
- Future versions may move them into submodules with deprecation warnings.
"""
from _duckdb import (
BinderException,
CaseExpression,
CatalogException,
CoalesceOperator,
ColumnExpression,
ConnectionException,
ConstantExpression,
ConstraintException,
ConversionException,
CSVLineTerminator,
DatabaseError,
DataError,
DefaultExpression,
DependencyException,
DuckDBPyConnection,
DuckDBPyRelation,
Error,
ExpectedResultType,
ExplainType,
Expression,
FatalException,
FunctionExpression,
HTTPException,
IntegrityError,
InternalError,
InternalException,
InterruptException,
InvalidInputException,
InvalidTypeException,
IOException,
LambdaExpression,
NotImplementedException,
NotSupportedError,
OperationalError,
OutOfMemoryException,
OutOfRangeException,
ParserException,
PermissionException,
ProgrammingError,
PythonExceptionHandling,
RenderMode,
SequenceException,
SerializationException,
SQLExpression,
StarExpression,
Statement,
StatementType,
SyntaxException,
TransactionException,
TypeMismatchException,
Warning,
__formatted_python_version__,
__git_revision__,
__interactive__,
__jupyter__,
__standard_vector_size__,
_clean_default_connection,
aggregate,
alias,
apilevel,
append,
array_type,
arrow,
begin,
checkpoint,
close,
commit,
connect,
create_function,
cursor,
decimal_type,
default_connection,
description,
df,
distinct,
dtype,
duplicate,
enum_type,
execute,
executemany,
extract_statements,
fetch_arrow_table,
fetch_df,
fetch_df_chunk,
fetch_record_batch,
fetchall,
fetchdf,
fetchmany,
fetchnumpy,
fetchone,
filesystem_is_registered,
filter,
from_arrow,
from_csv_auto,
from_df,
from_parquet,
from_query,
get_table_names,
install_extension,
interrupt,
limit,
list_filesystems,
list_type,
load_extension,
map_type,
order,
paramstyle,
pl,
project,
query,
query_df,
query_progress,
read_csv,
read_json,
read_parquet,
register,
register_filesystem,
remove_function,
rollback,
row_type,
rowcount,
set_default_connection,
sql,
sqltype,
string_type,
struct_type,
table,
table_function,
tf,
threadsafety,
token_type,
tokenize,
torch,
type,
union_type,
unregister,
unregister_filesystem,
values,
view,
write_csv,
)
from duckdb._dbapi_type_object import (
BINARY,
DATETIME,
NUMBER,
ROWID,
STRING,
DBAPITypeObject,
)
from duckdb._version import (
__duckdb_version__,
__version__,
version,
)
from duckdb.value.constant import (
BinaryValue,
BitValue,
BlobValue,
BooleanValue,
DateValue,
DecimalValue,
DoubleValue,
FloatValue,
HugeIntegerValue,
IntegerValue,
IntervalValue,
ListValue,
LongValue,
MapValue,
NullValue,
ShortValue,
StringValue,
StructValue,
TimestampMilisecondValue,
TimestampNanosecondValue,
TimestampSecondValue,
TimestampTimeZoneValue,
TimestampValue,
TimeTimeZoneValue,
TimeValue,
UnionType,
UnsignedBinaryValue,
UnsignedHugeIntegerValue,
UnsignedIntegerValue,
UnsignedLongValue,
UnsignedShortValue,
UUIDValue,
Value,
)
# Public names exported by `from duckdb import *`.
# Every name appears exactly once; the list is kept in case-sensitive sorted order.
__all__: list[str] = [
    "BinaryValue",
    "BinderException",
    "BitValue",
    "BlobValue",
    "BooleanValue",
    "CSVLineTerminator",
    "CaseExpression",
    "CatalogException",
    "CoalesceOperator",
    "ColumnExpression",
    "ConnectionException",
    "ConstantExpression",
    "ConstraintException",
    "ConversionException",
    "DataError",
    "DatabaseError",
    "DateValue",
    "DecimalValue",
    "DefaultExpression",
    "DependencyException",
    "DoubleValue",
    "DuckDBPyConnection",
    "DuckDBPyRelation",
    "Error",
    "ExpectedResultType",
    "ExplainType",
    "Expression",
    "FatalException",
    "FloatValue",
    "FunctionExpression",
    "HTTPException",
    "HugeIntegerValue",
    "IOException",
    "IntegerValue",
    "IntegrityError",
    "InternalError",
    "InternalException",
    "InterruptException",
    "IntervalValue",
    "InvalidInputException",
    "InvalidTypeException",
    "LambdaExpression",
    "ListValue",
    "LongValue",
    "MapValue",
    "NotImplementedException",
    "NotSupportedError",
    "NullValue",
    "OperationalError",
    "OutOfMemoryException",
    "OutOfRangeException",
    "ParserException",
    "PermissionException",
    "ProgrammingError",
    "PythonExceptionHandling",
    "RenderMode",
    "SQLExpression",
    "SequenceException",
    "SerializationException",
    "ShortValue",
    "StarExpression",
    "Statement",
    "StatementType",
    "StringValue",
    "StructValue",
    "SyntaxException",
    "TimeTimeZoneValue",
    "TimeValue",
    "TimestampMilisecondValue",
    "TimestampNanosecondValue",
    "TimestampSecondValue",
    "TimestampTimeZoneValue",
    "TimestampValue",
    "TransactionException",
    "TypeMismatchException",
    "UUIDValue",
    "UnionType",
    "UnsignedBinaryValue",
    "UnsignedHugeIntegerValue",
    "UnsignedIntegerValue",
    "UnsignedLongValue",
    "UnsignedShortValue",
    "Value",
    "Warning",
    "__formatted_python_version__",
    "__git_revision__",
    "__interactive__",
    "__jupyter__",
    "__standard_vector_size__",
    "__version__",
    "_clean_default_connection",
    "aggregate",
    "alias",
    "apilevel",
    "append",
    "array_type",
    "arrow",
    "begin",
    "checkpoint",
    "close",
    "commit",
    "connect",
    "create_function",
    "cursor",
    "decimal_type",
    "default_connection",
    "description",
    "df",
    "distinct",
    "dtype",
    "duplicate",
    "enum_type",
    "execute",
    "executemany",
    "extract_statements",
    "fetch_arrow_table",
    "fetch_df",
    "fetch_df_chunk",
    "fetch_record_batch",
    "fetchall",
    "fetchdf",
    "fetchmany",
    "fetchnumpy",
    "fetchone",
    "filesystem_is_registered",
    "filter",
    "from_arrow",
    "from_csv_auto",
    "from_df",
    "from_parquet",
    "from_query",
    "get_table_names",
    "install_extension",
    "interrupt",
    "limit",
    "list_filesystems",
    "list_type",
    "load_extension",
    "map_type",
    "order",
    "paramstyle",
    "pl",
    "project",
    "query",
    "query_df",
    "query_progress",
    "read_csv",
    "read_json",
    "read_parquet",
    "register",
    "register_filesystem",
    "remove_function",
    "rollback",
    "row_type",
    "rowcount",
    "set_default_connection",
    "sql",
    "sqltype",
    "string_type",
    "struct_type",
    "table",
    "table_function",
    "tf",
    "threadsafety",
    "token_type",
    "tokenize",
    "torch",
    "type",
    "union_type",
    "unregister",
    "unregister_filesystem",
    "values",
    "view",
    "write_csv",
]

View File

@@ -0,0 +1,231 @@
"""DuckDB DB API 2.0 Type Objects Module.
This module provides DB API 2.0 compliant type objects for DuckDB, allowing applications
to check column types returned by queries against standard database API categories.
Example:
>>> import duckdb
>>>
>>> conn = duckdb.connect()
>>> cursor = conn.cursor()
>>> cursor.execute("SELECT 'hello' as text_col, 42 as num_col, CURRENT_DATE as date_col")
>>>
>>> # Check column types using DB API type objects
>>> for i, desc in enumerate(cursor.description):
>>> col_name, col_type = desc[0], desc[1]
>>> if col_type == duckdb.STRING:
>>> print(f"{col_name} is a string type")
>>> elif col_type == duckdb.NUMBER:
>>> print(f"{col_name} is a numeric type")
>>> elif col_type == duckdb.DATETIME:
>>> print(f"{col_name} is a date/time type")
See Also:
- PEP 249: https://peps.python.org/pep-0249/
- DuckDB Type System: https://duckdb.org/docs/sql/data_types/overview
"""
from duckdb import sqltypes
class DBAPITypeObject:
    """PEP 249 (DB API 2.0) type object that groups DuckDB column types into one category.

    An instance stores the DuckDB types that make up a category such as STRING,
    NUMBER or DATETIME. Comparing the instance with a DuckDBPyType using ``==``
    reports whether that type is a member of the category.

    Args:
        types: The DuckDBPyType instances that make up this category.

    Example:
        >>> string_types = DBAPITypeObject([sqltypes.VARCHAR])
        >>> string_types == sqltypes.VARCHAR  # True
        >>> string_types == sqltypes.INTEGER  # False

    Note:
        DB API 2.0 specifies that type objects are matched with the equality
        operator rather than with isinstance() checks. Because the class defines
        ``__eq__`` without ``__hash__``, its instances are not hashable.
    """

    def __init__(self, types: list[sqltypes.DuckDBPyType]) -> None:
        """Store the member types of this category.

        Args:
            types: DuckDB types belonging to the category. The list is kept by
                reference, not copied.
        """
        self.types = types

    def __eq__(self, other: object) -> bool:
        """Report whether ``other`` is a DuckDB type that belongs to this category.

        Args:
            other: The value to test, normally a DuckDBPyType instance.

        Returns:
            True when ``other`` is a DuckDBPyType contained in ``self.types``;
            False for every other value, including non-DuckDBPyType objects.

        Example:
            >>> NUMBER == sqltypes.INTEGER  # True
            >>> NUMBER == sqltypes.VARCHAR  # False
        """
        # Anything that is not a DuckDBPyType can never be a member.
        return isinstance(other, sqltypes.DuckDBPyType) and other in self.types

    def __repr__(self) -> str:
        """Render the type object together with its member types.

        Returns:
            ``<DBAPITypeObject [...]>`` with the comma-separated ``str()`` of each member.

        Example:
            >>> repr(STRING)
            '<DBAPITypeObject [VARCHAR]>'
        """
        member_names = ",".join(map(str, self.types))
        return f"<DBAPITypeObject [{member_names}]>"
# Define the standard DB API 2.0 type objects for DuckDB
# STRING currently has a single member, VARCHAR; membership is tested with `STRING == some_type`.
STRING = DBAPITypeObject([sqltypes.VARCHAR])
"""
STRING type object for text-based database columns.
This type object represents all string/text types in DuckDB. Currently includes:
- VARCHAR: Variable-length character strings
Use this to check if a column contains textual data that should be handled
as Python strings.
DB API 2.0 Reference:
https://peps.python.org/pep-0249/#string
Example:
>>> cursor.description[0][1] == STRING # Check if first column is text
"""
# NUMBER lists the integer, arbitrary-precision/decimal and floating point types.
NUMBER = DBAPITypeObject(
    [
        sqltypes.TINYINT,
        sqltypes.UTINYINT,
        sqltypes.SMALLINT,
        sqltypes.USMALLINT,
        sqltypes.INTEGER,
        sqltypes.UINTEGER,
        sqltypes.BIGINT,
        sqltypes.UBIGINT,
        sqltypes.HUGEINT,
        sqltypes.UHUGEINT,
        # BIGNUM and DECIMAL are built from their type names instead of being
        # read from sqltypes constants.
        # NOTE(review): the bare "DECIMAL" carries no explicit width/scale here;
        # confirm that DECIMAL columns with other parameters still compare equal.
        sqltypes.DuckDBPyType("BIGNUM"),
        sqltypes.DuckDBPyType("DECIMAL"),
        sqltypes.FLOAT,
        sqltypes.DOUBLE,
    ]
)
"""
NUMBER type object for numeric database columns.
This type object represents all numeric types in DuckDB, including:
Integer Types:
- TINYINT, UTINYINT: 8-bit signed/unsigned integers
- SMALLINT, USMALLINT: 16-bit signed/unsigned integers
- INTEGER, UINTEGER: 32-bit signed/unsigned integers
- BIGINT, UBIGINT: 64-bit signed/unsigned integers
- HUGEINT, UHUGEINT: 128-bit signed/unsigned integers
Decimal Types:
- BIGNUM: Arbitrary precision integers
- DECIMAL: Fixed-point decimal numbers
Floating Point Types:
- FLOAT: 32-bit floating point
- DOUBLE: 64-bit floating point
Use this to check if a column contains numeric data that should be handled
as Python int, float, or Decimal objects.
DB API 2.0 Reference:
https://peps.python.org/pep-0249/#number
Example:
>>> cursor.description[1][1] == NUMBER # Check if second column is numeric
"""
# DATETIME lists the date, time and timestamp types, including the
# timezone-aware and alternative-precision variants.
DATETIME = DBAPITypeObject(
    [
        sqltypes.DATE,
        sqltypes.TIME,
        sqltypes.TIME_TZ,
        sqltypes.TIMESTAMP,
        sqltypes.TIMESTAMP_TZ,
        sqltypes.TIMESTAMP_NS,
        sqltypes.TIMESTAMP_MS,
        sqltypes.TIMESTAMP_S,
    ]
)
"""
DATETIME type object for date and time database columns.
This type object represents all date/time types in DuckDB, including:
Date Types:
- DATE: Calendar dates (year, month, day)
Time Types:
- TIME: Time of day without timezone
- TIME_TZ: Time of day with timezone
Timestamp Types:
- TIMESTAMP: Date and time without timezone (microsecond precision)
- TIMESTAMP_TZ: Date and time with timezone
- TIMESTAMP_NS: Nanosecond precision timestamps
- TIMESTAMP_MS: Millisecond precision timestamps
- TIMESTAMP_S: Second precision timestamps
Use this to check if a column contains temporal data that should be handled
as Python datetime, date, or time objects.
DB API 2.0 Reference:
https://peps.python.org/pep-0249/#datetime
Example:
>>> cursor.description[2][1] == DATETIME # Check if third column is date/time
"""
# BINARY has a single member, BLOB.
BINARY = DBAPITypeObject([sqltypes.BLOB])
"""
BINARY type object for binary data database columns.
This type object represents binary data types in DuckDB:
- BLOB: Binary Large Objects for storing arbitrary binary data
Use this to check if a column contains binary data that should be handled
as Python bytes objects.
DB API 2.0 Reference:
https://peps.python.org/pep-0249/#binary
Example:
>>> cursor.description[3][1] == BINARY # Check if fourth column is binary
"""
# ROWID is a plain None rather than a DBAPITypeObject, so `col_type == ROWID`
# is simply a comparison against None.
ROWID = None
"""
ROWID type object for row identifier columns.
DB API 2.0 Reference:
https://peps.python.org/pep-0249/#rowid
Note:
This will always be None for DuckDB connections. Applications should not
rely on ROWID functionality when using DuckDB.
"""

View File

@@ -0,0 +1,22 @@
# ----------------------------------------------------------------------
# Version API
#
# We provide three symbols:
# - duckdb.__version__: The version of this package
# - duckdb.__duckdb_version__: The version of duckdb that is bundled
# - duckdb.version(): A human-readable version string containing both of the above
# ----------------------------------------------------------------------
from importlib.metadata import version as _dist_version
import _duckdb
# Read from the installed distribution metadata registered under the name "duckdb".
__version__: str = _dist_version("duckdb")
"""Version of the DuckDB Python Package."""
# Read from the `__version__` attribute of the `_duckdb` module.
__duckdb_version__: str = _duckdb.__version__
"""Version of DuckDB that is bundled."""
def version() -> str:
    """Return a readable string naming the package version and the bundled DuckDB engine version."""
    package_version = __version__
    engine_version = _duckdb.__version__
    return f"{package_version} (with duckdb {engine_version})"

View File

@@ -0,0 +1,69 @@
"""StringIO buffer wrapper.
BSD 3-Clause License
Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team
All rights reserved.
Copyright (c) 2011-2022, Open source contributors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from io import StringIO, TextIOBase
from typing import Any, Union
class BytesIOWrapper:
    """Expose a text buffer through a bytes-returning ``read``.

    Created for compat with pyarrow read_csv. Attributes that this wrapper does
    not define are looked up on the wrapped buffer.
    """

    def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None:  # noqa: D107
        # One character may encode to several bytes, so encoding ``n`` characters
        # can yield more than ``n`` bytes. The surplus is parked here and emitted
        # at the front of the next read.
        self.overflow = b""
        self.encoding = encoding
        self.buffer = buffer

    def __getattr__(self, attr: str) -> Any:  # noqa: D105, ANN401
        # Only reached when normal attribute lookup fails; delegate to the wrapped buffer.
        return getattr(self.buffer, attr)

    def read(self, n: Union[int, None] = -1) -> bytes:  # noqa: D102
        assert self.buffer is not None
        encoded = self.buffer.read(n).encode(self.encoding)
        pending = self.overflow + encoded
        read_everything = n is None or n < 0
        # Unbounded read, or fewer bytes available than requested: hand back all of it.
        if read_everything or n >= len(pending):
            self.overflow = b""
            return pending
        # Otherwise return exactly n bytes and keep the remainder for the next call.
        self.overflow = pending[n:]
        return pending[:n]

View File

@@ -0,0 +1,5 @@
from . import spark # noqa: D104
# Names exported by a star-import of this package: only the `spark` subpackage.
__all__ = [
    "spark",
]

View File

@@ -0,0 +1,260 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache Software Foundation License 2.0
--------------------------------------
common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
core/src/main/java/org/apache/spark/util/collection/TimSort.java
core/src/main/resources/org/apache/spark/ui/static/bootstrap*
core/src/main/resources/org/apache/spark/ui/static/vis*
docs/js/vendor/bootstrap.js
connector/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java
Python Software Foundation License
----------------------------------
python/docs/source/_static/copybutton.js
BSD 3-Clause
------------
python/lib/py4j-*-src.zip
python/pyspark/cloudpickle/*.py
python/pyspark/join.py
core/src/main/resources/org/apache/spark/ui/static/d3.min.js
The CSS style for the navigation sidebar of the documentation was originally
submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project
is distributed under the 3-Clause BSD license.
MIT License
-----------
core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
core/src/main/resources/org/apache/spark/ui/static/*dataTables*
core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js
core/src/main/resources/org/apache/spark/ui/static/jquery*
core/src/main/resources/org/apache/spark/ui/static/sorttable.js
docs/js/vendor/anchor.min.js
docs/js/vendor/jquery*
docs/js/vendor/modernizer*
Creative Commons CC0 1.0 Universal Public Domain Dedication
-----------------------------------------------------------
(see LICENSE-CC0.txt)
data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
data/mllib/images/kittens/54893.jpg
data/mllib/images/kittens/DP153539.jpg
data/mllib/images/kittens/DP802813.jpg
data/mllib/images/multi-channel/chr30.4.184.jpg

View File

@@ -0,0 +1,6 @@
from .conf import SparkConf # noqa: D104
from .context import SparkContext
from .exception import ContributionsAcceptedError
from .sql import DataFrame, SparkSession
# Names exported by a star-import of this package; each one is imported above.
__all__ = ["ContributionsAcceptedError", "DataFrame", "SparkConf", "SparkContext", "SparkSession"]

View File

@@ -0,0 +1,77 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module defining global singleton classes.
This module raises a RuntimeError if an attempt to reload it is made. In that
way the identities of the classes defined here are fixed and will remain so
even if duckdb spark itself is reloaded. In particular, a function like the following
will still work correctly after duckdb spark is reloaded:
def foo(arg=duckdb.experimental.spark._globals._NoValue):
if arg is duckdb.experimental.spark._globals._NoValue:
...
See gh-7844 for a discussion of the reload problem that motivated this module.
Note that this approach is taken from NumPy.
"""
# Export list honoured by star-imports. The name must be spelled in lowercase
# (`__all__`) for the import machinery to recognise it.
__all__ = ["_NoValue"]

# Disallow reloading this module so as to preserve the identities of the
# classes defined here.
if "_is_loaded" in globals():
    msg = "Reloading duckdb.experimental.spark._globals is not allowed"
    raise RuntimeError(msg)
_is_loaded = True


class _NoValueType:
    """Special keyword value.

    The instance of this class may be used as the default value assigned to a
    deprecated keyword in order to check if it has been given a user defined
    value.

    The class is a singleton: every call to ``_NoValueType()`` returns the same
    object, so callers can rely on ``arg is _NoValue`` identity checks.

    This class was copied from NumPy.
    """

    # Holds the single instance once it has been created.
    __instance = None

    def __new__(cls) -> "_NoValueType":
        """Return the single shared instance, creating it on first use."""
        # Compare against None explicitly: the instance itself is falsey (see
        # __bool__), so a truthiness test would create a new object on every call
        # and break the singleton guarantee.
        if cls.__instance is None:
            cls.__instance = super().__new__(cls)
        return cls.__instance

    # Make the _NoValue instance falsey
    def __nonzero__(self) -> bool:
        """Return False so that the sentinel is falsey in boolean contexts."""
        return False

    __bool__ = __nonzero__

    # Unpickling calls the class with no arguments, which goes back through
    # __new__ and therefore yields the existing singleton.
    def __reduce__(self) -> tuple[type, tuple]:
        """Return the ``(callable, args)`` pair used by pickle to rebuild the sentinel."""
        return (self.__class__, ())

    def __repr__(self) -> str:
        """Return the fixed display text of the sentinel."""
        return "<no value>"


_NoValue = _NoValueType()

View File

@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterable, Sized
from typing import Callable, TypeVar, Union
from numpy import float32, float64, int32, int64, ndarray
from typing_extensions import Literal, Protocol, Self
# Type variable bound to `Callable`.
F = TypeVar("F", bound=Callable)
# Covariant type variable; used as the element type of `SizedIterable` below.
T_co = TypeVar("T_co", covariant=True)
# Union of the plain Python scalar types bool, float, int and str.
PrimitiveType = Union[bool, float, int, str]
# Literal type that only admits the value 0.
NonUDFType = Literal[0]
class SupportsIAdd(Protocol):
def __iadd__(self, other: "SupportsIAdd") -> Self: ...
class SupportsOrdering(Protocol):
    """Structural type for objects that implement the ``<`` comparison."""

    def __lt__(self, other: "SupportsOrdering") -> bool:
        """Report whether ``self`` orders before ``other``."""
class SizedIterable(Protocol, Sized, Iterable[T_co]):
    """Structural type combining ``Sized`` and ``Iterable[T_co]``; declares no members of its own."""
# Type variable bound to `SupportsOrdering`, i.e. types that define `__lt__`.
S = TypeVar("S", bound=SupportsOrdering)
# Constrained type variable: exactly one of the listed Python / NumPy scalar types or a NumPy ndarray.
NumberOrArray = TypeVar("NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray)

View File

@@ -0,0 +1,46 @@
from typing import Optional # noqa: D100
from duckdb.experimental.spark.exception import ContributionsAcceptedError
class SparkConf:
    """Placeholder mirroring the method surface of PySpark's ``SparkConf``.

    Nothing is implemented: the constructor raises ``NotImplementedError`` and
    every method raises ``ContributionsAcceptedError``.
    """

    def __init__(self) -> None:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError

    def contains(self, key: str) -> bool:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def getAll(self) -> list[tuple[str, str]]:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def set(self, key: str, value: str) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setAppName(self, value: str) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setExecutorEnv(
        self,
        key: Optional[str] = None,
        value: Optional[str] = None,
        pairs: Optional[list[tuple[str, str]]] = None,
    ) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setIfMissing(self, key: str, value: str) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setMaster(self, value: str) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setSparkHome(self, value: str) -> "SparkConf":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def toDebugString(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError


__all__ = ["SparkConf"]

View File

@@ -0,0 +1,180 @@
from typing import Optional # noqa: D100
import duckdb
from duckdb import DuckDBPyConnection
from duckdb.experimental.spark.conf import SparkConf
from duckdb.experimental.spark.exception import ContributionsAcceptedError
class SparkContext:
    """Stand-in for PySpark's ``SparkContext`` that wraps an in-memory DuckDB connection.

    Only ``connection`` and ``stop`` do real work; every other member raises
    ``ContributionsAcceptedError``. The commented-out signatures record PySpark
    methods that have no counterpart here yet.
    """

    def __init__(self, master: str) -> None:
        """Open an in-memory DuckDB connection and configure its null ordering.

        Args:
            master: Accepted as part of the signature; not read by this constructor.
        """
        self._connection = duckdb.connect(":memory:")
        # This aligns the null ordering with Spark.
        self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'")

    @property
    def connection(self) -> DuckDBPyConnection:
        """The DuckDB connection created by the constructor."""
        return self._connection

    def stop(self) -> None:
        """Close the underlying DuckDB connection."""
        self._connection.close()

    @classmethod
    def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext":
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @classmethod
    def setSystemProperty(cls, key: str, value: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def applicationId(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def defaultMinPartitions(self) -> int:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def defaultParallelism(self) -> int:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # @property
    # def resources(self) -> Dict[str, ResourceInformation]:
    # raise ContributionsAcceptedError

    @property
    def startTime(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def uiWebUrl(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def version(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def __repr__(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None
    # ) -> 'Accumulator[T]':
    # pass

    def addArchive(self, path: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def addFile(self, path: str, recursive: bool = False) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def addPyFile(self, path: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # def binaryFiles(self, path: str, minPartitions: Optional[int] = None
    # ) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]:
    # pass
    # def binaryRecords(self, path: str, recordLength: int) -> duckdb.experimental.spark.rdd.RDD[bytes]:
    # pass
    # def broadcast(self, value: ~T) -> 'Broadcast[T]':
    # pass

    def cancelAllJobs(self) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def cancelJobGroup(self, groupId: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def dump_profiles(self, path: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]:
    # pass

    def getCheckpointDir(self) -> Optional[str]:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def getConf(self) -> SparkConf:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def getLocalProperty(self, key: str) -> Optional[str]:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str,
    # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
    # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
    # pass
    # def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None,
    # valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0
    # ) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
    # pass
    # def newAPIHadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str,
    # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
    # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
    # pass
    # def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str,
    # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None,
    # conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
    # pass
    # def parallelize(self, c: Iterable[~T], numSlices: Optional[int] = None) -> pyspark.rdd.RDD[~T]:
    # pass
    # def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> pyspark.rdd.RDD[typing.Any]:
    # pass
    # def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None
    # ) -> pyspark.rdd.RDD[int]:
    # pass
    # def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]],
    # partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]:
    # pass
    # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None,
    # keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None,
    # batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]:
    # pass

    def setCheckpointDir(self, dirName: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setJobDescription(self, value: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setLocalProperty(self, key: str, value: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def setLogLevel(self, logLevel: str) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def show_profiles(self) -> None:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    def sparkUser(self) -> str:
        """Not implemented; raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    # def statusTracker(self) -> duckdb.experimental.spark.status.StatusTracker:
    # raise ContributionsAcceptedError
    # def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True
    # ) -> pyspark.rdd.RDD[str]:
    # pass
    # def union(self, rdds: List[pyspark.rdd.RDD[~T]]) -> pyspark.rdd.RDD[~T]:
    # pass
    # def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True
    # ) -> pyspark.rdd.RDD[typing.Tuple[str, str]]:
    # pass


__all__ = ["SparkContext"]

View File

@@ -0,0 +1,70 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PySpark exceptions."""
from .exceptions.base import (
AnalysisException,
ArithmeticException,
ArrayIndexOutOfBoundsException,
DateTimeException,
IllegalArgumentException,
NumberFormatException,
ParseException,
PySparkAssertionError,
PySparkAttributeError,
PySparkException,
PySparkIndexError,
PySparkNotImplementedError,
PySparkRuntimeError,
PySparkTypeError,
PySparkValueError,
PythonException,
QueryExecutionException,
SparkRuntimeException,
SparkUpgradeException,
StreamingQueryException,
TempTableAlreadyExistsException,
UnknownException,
UnsupportedOperationException,
)
# Public names of this module: exactly the exception classes re-exported from
# `.exceptions.base` by the import statement above, in the same order.
__all__ = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkAssertionError",
"PySparkAttributeError",
"PySparkException",
"PySparkIndexError",
"PySparkNotImplementedError",
"PySparkRuntimeError",
"PySparkTypeError",
"PySparkValueError",
"PythonException",
"QueryExecutionException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"TempTableAlreadyExistsException",
"UnknownException",
"UnsupportedOperationException",
]

View File

@@ -0,0 +1,918 @@
# ruff: noqa: D100, E501
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
ERROR_CLASSES_JSON = """
{
"APPLICATION_NAME_NOT_SET" : {
"message" : [
"An application name must be set in your configuration."
]
},
"ARGUMENT_REQUIRED": {
"message": [
"Argument `<arg_name>` is required when <condition>."
]
},
"ATTRIBUTE_NOT_CALLABLE" : {
"message" : [
"Attribute `<attr_name>` in provided object `<obj_name>` is not callable."
]
},
"ATTRIBUTE_NOT_SUPPORTED" : {
"message" : [
"Attribute `<attr_name>` is not supported."
]
},
"AXIS_LENGTH_MISMATCH" : {
"message" : [
"Length mismatch: Expected axis has <expected_length> element, new values have <actual_length> elements."
]
},
"BROADCAST_VARIABLE_NOT_LOADED": {
"message": [
"Broadcast variable `<variable>` not loaded."
]
},
"CALL_BEFORE_INITIALIZE": {
"message": [
"Not supported to call `<func_name>` before initialize <object>."
]
},
"CANNOT_ACCEPT_OBJECT_IN_TYPE": {
"message": [
"`<data_type>` can not accept object `<obj_name>` in type `<obj_type>`."
]
},
"CANNOT_ACCESS_TO_DUNDER": {
"message": [
"Dunder(double underscore) attribute is for internal use only."
]
},
"CANNOT_APPLY_IN_FOR_COLUMN": {
"message": [
"Cannot apply 'in' operator against a column: please use 'contains' in a string column or 'array_contains' function for an array column."
]
},
"CANNOT_BE_EMPTY": {
"message": [
"At least one <item> must be specified."
]
},
"CANNOT_BE_NONE": {
"message": [
"Argument `<arg_name>` can not be None."
]
},
"CANNOT_CONVERT_COLUMN_INTO_BOOL": {
"message": [
"Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions."
]
},
"CANNOT_CONVERT_TYPE": {
"message": [
"Cannot convert <from_type> into <to_type>."
]
},
"CANNOT_DETERMINE_TYPE": {
"message": [
"Some of types cannot be determined after inferring."
]
},
"CANNOT_GET_BATCH_ID": {
"message": [
"Could not get batch id from <obj_name>."
]
},
"CANNOT_INFER_ARRAY_TYPE": {
"message": [
"Can not infer Array Type from an list with None as the first element."
]
},
"CANNOT_INFER_EMPTY_SCHEMA": {
"message": [
"Can not infer schema from empty dataset."
]
},
"CANNOT_INFER_SCHEMA_FOR_TYPE": {
"message": [
"Can not infer schema for type: `<data_type>`."
]
},
"CANNOT_INFER_TYPE_FOR_FIELD": {
"message": [
"Unable to infer the type of the field `<field_name>`."
]
},
"CANNOT_MERGE_TYPE": {
"message": [
"Can not merge type `<data_type1>` and `<data_type2>`."
]
},
"CANNOT_OPEN_SOCKET": {
"message": [
"Can not open socket: <errors>."
]
},
"CANNOT_PARSE_DATATYPE": {
"message": [
"Unable to parse datatype. <msg>."
]
},
"CANNOT_PROVIDE_METADATA": {
"message": [
"metadata can only be provided for a single column."
]
},
"CANNOT_SET_TOGETHER": {
"message": [
"<arg_list> should not be set together."
]
},
"CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF": {
"message": [
"returnType can not be specified when `<arg_name>` is a user-defined function, but got <return_type>."
]
},
"COLUMN_IN_LIST": {
"message": [
"`<func_name>` does not allow a Column in a list."
]
},
"CONTEXT_ONLY_VALID_ON_DRIVER" : {
"message" : [
"It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063."
]
},
"CONTEXT_UNAVAILABLE_FOR_REMOTE_CLIENT" : {
"message" : [
"Remote client cannot create a SparkContext. Create SparkSession instead."
]
},
"DIFFERENT_PANDAS_DATAFRAME" : {
"message" : [
"DataFrames are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_INDEX" : {
"message" : [
"Indices are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_MULTIINDEX" : {
"message" : [
"MultiIndices are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_PANDAS_SERIES" : {
"message" : [
"Series are not almost equal:",
"Left:",
"<left>",
"<left_dtype>",
"Right:",
"<right>",
"<right_dtype>"
]
},
"DIFFERENT_ROWS" : {
"message" : [
"<error_msg>"
]
},
"DIFFERENT_SCHEMA" : {
"message" : [
"Schemas do not match.",
"--- actual",
"+++ expected",
"<error_msg>"
]
},
"DISALLOWED_TYPE_FOR_CONTAINER" : {
"message" : [
"Argument `<arg_name>`(type: <arg_type>) should only contain a type in [<allowed_types>], got <return_type>"
]
},
"DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : {
"message" : [
"Duplicated field names in Arrow Struct are not allowed, got <field_names>"
]
},
"HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : {
"message" : [
"Function `<func_name>` should return Column, got <return_type>."
]
},
"INCORRECT_CONF_FOR_PROFILE" : {
"message" : [
"`spark.python.profile` or `spark.python.profile.memory` configuration",
" must be set to `true` to enable Python profile."
]
},
"INVALID_ARROW_UDTF_RETURN_TYPE" : {
"message" : [
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>."
]
},
"INVALID_BROADCAST_OPERATION": {
"message": [
"Broadcast can only be <operation> in driver."
]
},
"INVALID_CALL_ON_UNRESOLVED_OBJECT": {
"message": [
"Invalid call to `<func_name>` on unresolved object."
]
},
"INVALID_CONNECT_URL" : {
"message" : [
"Invalid URL for Spark Connect: <detail>"
]
},
"INVALID_ITEM_FOR_CONTAINER": {
"message": [
"All items in `<arg_name>` should be in <allowed_types>, got <item_type>."
]
},
"INVALID_NDARRAY_DIMENSION": {
"message": [
"NumPy array input should be of <dimensions> dimensions."
]
},
"INVALID_PANDAS_UDF" : {
"message" : [
"Invalid function: <detail>"
]
},
"INVALID_PANDAS_UDF_TYPE" : {
"message" : [
"`<arg_name>` should be one the values from PandasUDFType, got <arg_type>"
]
},
"INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
"message": [
"Pandas UDF should return StructType for <eval_type>, got <return_type>."
]
},
"INVALID_TIMEOUT_TIMESTAMP" : {
"message" : [
"Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
]
},
"INVALID_TYPE" : {
"message" : [
"Argument `<arg_name>` should not be a <data_type>."
]
},
"INVALID_TYPENAME_CALL" : {
"message" : [
"StructField does not have typeName. Use typeName on its type explicitly instead."
]
},
"INVALID_TYPE_DF_EQUALITY_ARG" : {
"message" : [
"Expected type <expected_type> for `<arg_name>` but got type <actual_type>."
]
},
"INVALID_UDF_EVAL_TYPE" : {
"message" : [
"Eval type for UDF must be <eval_type>."
]
},
"INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE" : {
"message" : [
"The UDTF '<name>' is invalid. It has both its return type and an 'analyze' attribute. Please make it have one of either the return type or the 'analyze' static method in '<name>' and try again."
]
},
"INVALID_UDTF_EVAL_TYPE" : {
"message" : [
"The eval type for the UDTF '<name>' is invalid. It must be one of <eval_type>."
]
},
"INVALID_UDTF_HANDLER_TYPE" : {
"message" : [
"The UDTF is invalid. The function handler must be a class, but got '<type>'. Please provide a class as the function handler."
]
},
"INVALID_UDTF_NO_EVAL" : {
"message" : [
"The UDTF '<name>' is invalid. It does not implement the required 'eval' method. Please implement the 'eval' method in '<name>' and try again."
]
},
"INVALID_UDTF_RETURN_TYPE" : {
"message" : [
"The UDTF '<name>' is invalid. It does not specify its return type or implement the required 'analyze' static method. Please specify the return type or implement the 'analyze' static method in '<name>' and try again."
]
},
"INVALID_WHEN_USAGE": {
"message": [
"when() can only be applied on a Column previously generated by when() function, and cannot be applied once otherwise() is applied."
]
},
"INVALID_WINDOW_BOUND_TYPE" : {
"message" : [
"Invalid window bound type: <window_bound_type>."
]
},
"JAVA_GATEWAY_EXITED" : {
"message" : [
"Java gateway process exited before sending its port number."
]
},
"JVM_ATTRIBUTE_NOT_SUPPORTED" : {
"message" : [
"Attribute `<attr_name>` is not supported in Spark Connect as it depends on the JVM. If you need to use this attribute, do not use Spark Connect when creating your session."
]
},
"KEY_VALUE_PAIR_REQUIRED" : {
"message" : [
"Key-value pair or a list of pairs is required."
]
},
"LENGTH_SHOULD_BE_THE_SAME" : {
"message" : [
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MASTER_URL_NOT_SET" : {
"message" : [
"A master URL must be set in your configuration."
]
},
"MISSING_LIBRARY_FOR_PROFILER" : {
"message" : [
"Install the 'memory_profiler' library in the cluster to enable memory profiling."
]
},
"MISSING_VALID_PLAN" : {
"message" : [
"Argument to <operator> does not contain a valid plan."
]
},
"MIXED_TYPE_REPLACEMENT" : {
"message" : [
"Mixed type replacements are not supported."
]
},
"NEGATIVE_VALUE" : {
"message" : [
"Value for `<arg_name>` must be greater than or equal to 0, got '<arg_value>'."
]
},
"NOT_BOOL" : {
"message" : [
"Argument `<arg_name>` should be a bool, got <arg_type>."
]
},
"NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a bool, dict, float, int, str or tuple, got <arg_type>."
]
},
"NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool, dict, float, int or str, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a bool, float or str, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a bool, float, int, list, None, str or tuple, got <arg_type>."
]
},
"NOT_BOOL_OR_FLOAT_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool, float, int or str, got <arg_type>."
]
},
"NOT_BOOL_OR_LIST" : {
"message" : [
"Argument `<arg_name>` should be a bool or list, got <arg_type>."
]
},
"NOT_BOOL_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a bool or str, got <arg_type>."
]
},
"NOT_CALLABLE" : {
"message" : [
"Argument `<arg_name>` should be a callable, got <arg_type>."
]
},
"NOT_COLUMN" : {
"message" : [
"Argument `<arg_name>` should be a Column, got <arg_type>."
]
},
"NOT_COLUMN_OR_DATATYPE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, str or DataType, but got <arg_type>."
]
},
"NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a column, float, integer, list or string, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a Column or int, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a Column, int, list, str or tuple, got <arg_type>."
]
},
"NOT_COLUMN_OR_INT_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, int or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column, list or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Column or str, got <arg_type>."
]
},
"NOT_COLUMN_OR_STR_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a StructType, Column or str, got <arg_type>."
]
},
"NOT_DATAFRAME" : {
"message" : [
"Argument `<arg_name>` should be a DataFrame, got <arg_type>."
]
},
"NOT_DATATYPE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a DataType or str, got <arg_type>."
]
},
"NOT_DICT" : {
"message" : [
"Argument `<arg_name>` should be a dict, got <arg_type>."
]
},
"NOT_EXPRESSION" : {
"message" : [
"Argument `<arg_name>` should be a Expression, got <arg_type>."
]
},
"NOT_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a float or int, got <arg_type>."
]
},
"NOT_FLOAT_OR_INT_OR_LIST_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a float, int, list or str, got <arg_type>."
]
},
"NOT_IMPLEMENTED" : {
"message" : [
"<feature> is not implemented."
]
},
"NOT_INSTANCE_OF" : {
"message" : [
"<value> is not an instance of type <data_type>."
]
},
"NOT_INT" : {
"message" : [
"Argument `<arg_name>` should be an int, got <arg_type>."
]
},
"NOT_INT_OR_SLICE_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be an int, slice or str, got <arg_type>."
]
},
"NOT_IN_BARRIER_STAGE" : {
"message" : [
"It is not in a barrier stage."
]
},
"NOT_ITERABLE" : {
"message" : [
"<objectName> is not iterable."
]
},
"NOT_LIST" : {
"message" : [
"Argument `<arg_name>` should be a list, got <arg_type>."
]
},
"NOT_LIST_OF_COLUMN" : {
"message" : [
"Argument `<arg_name>` should be a list[Column]."
]
},
"NOT_LIST_OF_COLUMN_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a list[Column]."
]
},
"NOT_LIST_OF_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a list[float, int], got <arg_type>."
]
},
"NOT_LIST_OF_STR" : {
"message" : [
"Argument `<arg_name>` should be a list[str], got <arg_type>."
]
},
"NOT_LIST_OR_NONE_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a list, None or StructType, got <arg_type>."
]
},
"NOT_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a list, str or tuple, got <arg_type>."
]
},
"NOT_LIST_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a list or tuple, got <arg_type>."
]
},
"NOT_NUMERIC_COLUMNS" : {
"message" : [
"Numeric aggregation function can only be applied on numeric columns, got <invalid_columns>."
]
},
"NOT_OBSERVATION_OR_STR" : {
"message" : [
"Argument `<arg_name>` should be a Observation or str, got <arg_type>."
]
},
"NOT_SAME_TYPE" : {
"message" : [
"Argument `<arg_name1>` and `<arg_name2>` should be the same type, got <arg_type1> and <arg_type2>."
]
},
"NOT_STR" : {
"message" : [
"Argument `<arg_name>` should be a str, got <arg_type>."
]
},
"NOT_STR_OR_LIST_OF_RDD" : {
"message" : [
"Argument `<arg_name>` should be a str or list[RDD], got <arg_type>."
]
},
"NOT_STR_OR_STRUCT" : {
"message" : [
"Argument `<arg_name>` should be a str or structType, got <arg_type>."
]
},
"NOT_WINDOWSPEC" : {
"message" : [
"Argument `<arg_name>` should be a WindowSpec, got <arg_type>."
]
},
"NO_ACTIVE_OR_DEFAULT_SESSION" : {
"message" : [
"No active or default Spark session found. Please create a new Spark session before running the code."
]
},
"NO_ACTIVE_SESSION" : {
"message" : [
"No active Spark session found. Please create a new Spark session before running the code."
]
},
"ONLY_ALLOWED_FOR_SINGLE_COLUMN" : {
"message" : [
"Argument `<arg_name>` can only be provided for a single column."
]
},
"ONLY_ALLOW_SINGLE_TRIGGER" : {
"message" : [
"Only a single trigger is allowed."
]
},
"PIPE_FUNCTION_EXITED" : {
"message" : [
"Pipe function `<func_name>` exited with error code <error_code>."
]
},
"PYTHON_HASH_SEED_NOT_SET" : {
"message" : [
"Randomness of hash of string should be disabled via PYTHONHASHSEED."
]
},
"PYTHON_VERSION_MISMATCH" : {
"message" : [
"Python in worker has different version: <worker_version> than that in driver: <driver_version>, PySpark cannot run with different minor versions.",
"Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set."
]
},
"RDD_TRANSFORM_ONLY_VALID_ON_DRIVER" : {
"message" : [
"It appears that you are attempting to broadcast an RDD or reference an RDD from an ",
"action or transformation. RDD transformations and actions can only be invoked by the ",
"driver, not inside of other transformations; for example, ",
"rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values ",
"transformation and count action cannot be performed inside of the rdd1.map ",
"transformation. For more information, see SPARK-5063."
]
},
"RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>"
]
},
"RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Number of columns of the returned pandas.DataFrame doesn't match specified schema. Expected: <expected> Actual: <actual>"
]
},
"RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF" : {
"message" : [
"The length of output in Scalar iterator pandas UDF should be the same with the input's; however, the length of output was <output_length> and the length of input was <input_length>."
]
},
"SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Result vector from pandas_udf was not the required length: expected <expected>, got <actual>."
]
},
"SESSION_ALREADY_EXIST" : {
"message" : [
"Cannot start a remote Spark session because there is a regular Spark session already running."
]
},
"SESSION_NOT_SAME" : {
"message" : [
"Both Datasets must belong to the same SparkSession."
]
},
"SESSION_OR_CONTEXT_EXISTS" : {
"message" : [
"There should not be an existing Spark Session or Spark Context."
]
},
"SHOULD_NOT_DATAFRAME": {
"message": [
"Argument `<arg_name>` should not be a DataFrame."
]
},
"SLICE_WITH_STEP" : {
"message" : [
"Slice with step is not supported."
]
},
"STATE_NOT_EXISTS" : {
"message" : [
"State is either not defined or has already been removed."
]
},
"STOP_ITERATION_OCCURRED" : {
"message" : [
"Caught StopIteration thrown from user's code; failing the task: <exc>"
]
},
"STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF" : {
"message" : [
"pandas iterator UDF should exhaust the input iterator."
]
},
"STREAMING_CONNECT_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the function `<name>`. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`."
]
},
"TOO_MANY_VALUES" : {
"message" : [
"Expected <expected> values for `<item>`, got <actual>."
]
},
"UDF_RETURN_TYPE" : {
"message" : [
"Return type of the user-defined function should be <expected>, but is <actual>."
]
},
"UDTF_ARROW_TYPE_CAST_ERROR" : {
"message" : [
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
]
},
"UDTF_EXEC_ERROR" : {
"message" : [
"User defined table function encountered an error in the '<method_name>' method: <error>"
]
},
"UDTF_INVALID_OUTPUT_ROW_TYPE" : {
"message" : [
"The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
]
},
"UDTF_RETURN_NOT_ITERABLE" : {
"message" : [
"The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
]
},
"UDTF_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema."
]
},
"UDTF_RETURN_TYPE_MISMATCH" : {
"message" : [
"Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', but got '<return_type>'. Please ensure the return type is a correctly formatted StructType."
]
},
"UDTF_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the UDTF '<name>': <message>"
]
},
"UNEXPECTED_RESPONSE_FROM_SERVER" : {
"message" : [
"Unexpected response from iterator server."
]
},
"UNEXPECTED_TUPLE_WITH_STRUCT" : {
"message" : [
"Unexpected tuple <tuple> with StructType."
]
},
"UNKNOWN_EXPLAIN_MODE" : {
"message" : [
"Unknown explain mode: '<explain_mode>'. Accepted explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'."
]
},
"UNKNOWN_INTERRUPT_TYPE" : {
"message" : [
"Unknown interrupt type: '<interrupt_type>'. Accepted interrupt types are 'all'."
]
},
"UNKNOWN_RESPONSE" : {
"message" : [
"Unknown response: <response>."
]
},
"UNSUPPORTED_DATA_TYPE" : {
"message" : [
"Unsupported DataType `<data_type>`."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW" : {
"message" : [
"Single data type <data_type> is not supported with Arrow."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : {
"message" : [
"<data_type> is not supported in conversion to Arrow."
]
},
"UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION" : {
"message" : [
"<data_type> is only supported with pyarrow 2.0.0 and above."
]
},
"UNSUPPORTED_JOIN_TYPE" : {
"message" : [
"Unsupported join type: <join_type>. Supported join types include: \\"inner\\", \\"outer\\", \\"full\\", \\"fullouter\\", \\"full_outer\\", \\"leftouter\\", \\"left\\", \\"left_outer\\", \\"rightouter\\", \\"right\\", \\"right_outer\\", \\"leftsemi\\", \\"left_semi\\", \\"semi\\", \\"leftanti\\", \\"left_anti\\", \\"anti\\", \\"cross\\"."
]
},
"UNSUPPORTED_LITERAL" : {
"message" : [
"Unsupported Literal '<literal>'."
]
},
"UNSUPPORTED_NUMPY_ARRAY_SCALAR" : {
"message" : [
"The type of array scalar '<dtype>' is not supported."
]
},
"UNSUPPORTED_OPERATION" : {
"message" : [
"<operation> is not supported."
]
},
"UNSUPPORTED_PARAM_TYPE_FOR_HIGHER_ORDER_FUNCTION" : {
"message" : [
"Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
]
},
"UNSUPPORTED_SIGNATURE" : {
"message" : [
"Unsupported signature: <signature>."
]
},
"UNSUPPORTED_WITH_ARROW_OPTIMIZATION" : {
"message" : [
"<feature> is not supported with Arrow optimization enabled in Python UDFs. Disable 'spark.sql.execution.pythonUDF.arrow.enabled' to workaround.."
]
},
"VALUE_NOT_ACCESSIBLE": {
"message": [
"Value `<value>` cannot be accessed inside tasks."
]
},
"VALUE_NOT_ALLOWED" : {
"message" : [
"Value for `<arg_name>` has to be amongst the following values: <allowed_values>."
]
},
"VALUE_NOT_ANY_OR_ALL" : {
"message" : [
"Value for `<arg_name>` must be 'any' or 'all', got '<arg_value>'."
]
},
"VALUE_NOT_BETWEEN" : {
"message" : [
"Value for `<arg_name>` must be between <min> and <max>."
]
},
"VALUE_NOT_NON_EMPTY_STR" : {
"message" : [
"Value for `<arg_name>` must be a non empty string, got '<arg_value>'."
]
},
"VALUE_NOT_PEARSON" : {
"message" : [
"Value for `<arg_name>` only supports the 'pearson', got '<arg_value>'."
]
},
"VALUE_NOT_POSITIVE" : {
"message" : [
"Value for `<arg_name>` must be positive, got '<arg_value>'."
]
},
"VALUE_NOT_TRUE" : {
"message" : [
"Value for `<arg_name>` must be True, got '<arg_value>'."
]
},
"VALUE_OUT_OF_BOUND" : {
"message" : [
"Value for `<arg_name>` must be greater than <lower_bound> or less than <upper_bound>, got <actual>"
]
},
"WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION" : {
"message" : [
"Function `<func_name>` should take between 1 and 3 arguments, but provided function takes <num_args>."
]
},
"WRONG_NUM_COLUMNS" : {
"message" : [
"Function `<func_name>` should take at least <num_cols> columns."
]
},
"ZERO_INDEX": {
"message": [
"Index must be non-zero."
]
}
}
"""
# Parse the JSON document above once, at import time. The result maps each
# error-class name to an object whose "message" key holds a list of template lines.
ERROR_CLASSES_MAP = json.loads(ERROR_CLASSES_JSON)

View File

@@ -0,0 +1,16 @@
# # noqa: D104
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,168 @@
from typing import Optional, cast # noqa: D100
from ..utils import ErrorClassesReader
class PySparkException(Exception):
    """Root of the exception hierarchy for errors raised by the PySpark-compatible API.

    An instance is created in exactly one of two ways:

    * from a ready-made ``message``; or
    * from an ``error_class`` together with its ``message_parameters``, in which
      case the text is rendered from the matching template by ``ErrorClassesReader``.
    """

    def __init__(  # noqa: D107
        self,
        message: Optional[str] = None,
        # Key selecting the message template; must be one of the entries in 'error_classes.py'.
        error_class: Optional[str] = None,
        # Values for the placeholders of the template selected by `error_class`.
        message_parameters: Optional[dict[str, str]] = None,
    ) -> None:
        has_message = message is not None
        has_error_class = error_class is not None
        has_parameters = message_parameters is not None
        # Either a literal message on its own, or an error class *with* its parameters.
        assert (has_message and not has_error_class and not has_parameters) or (
            not has_message and has_error_class and has_parameters
        )

        self.error_reader = ErrorClassesReader()

        if has_message:
            self.message = message
        else:
            self.message = self.error_reader.get_error_message(
                cast("str", error_class), cast("dict[str, str]", message_parameters)
            )

        self.error_class = error_class
        self.message_parameters = message_parameters

    def getErrorClass(self) -> Optional[str]:
        """Return the error class given at construction, or ``None`` if there was none.

        .. versionadded:: 3.4.0

        See Also:
        --------
        :meth:`PySparkException.getMessageParameters`
        :meth:`PySparkException.getSqlState`
        """
        return self.error_class

    def getMessageParameters(self) -> Optional[dict[str, str]]:
        """Return the message parameters given at construction, or ``None`` if there were none.

        .. versionadded:: 3.4.0

        See Also:
        --------
        :meth:`PySparkException.getErrorClass`
        :meth:`PySparkException.getSqlState`
        """
        return self.message_parameters

    def getSqlState(self) -> None:
        """Return the SQLSTATE of this error, which is always ``None``.

        Errors generated in Python have no SQLSTATE.

        .. versionadded:: 3.4.0

        See Also:
        --------
        :meth:`PySparkException.getErrorClass`
        :meth:`PySparkException.getMessageParameters`
        """
        return None

    def __str__(self) -> str:  # noqa: D105
        error_class = self.getErrorClass()
        if error_class is None:
            return self.message
        # Prefix the rendered text with the error class in square brackets.
        return f"[{error_class}] {self.message}"
class AnalysisException(PySparkException):
    """Failed to analyze a SQL query plan."""


class SessionNotSameException(PySparkException):
    """Performed the same operation on different SparkSessions."""


class TempTableAlreadyExistsException(AnalysisException):
    """Failed to create a temp view because it already exists."""


class ParseException(AnalysisException):
    """Failed to parse a SQL command."""


class IllegalArgumentException(PySparkException):
    """Passed an illegal or inappropriate argument."""


class ArithmeticException(PySparkException):
    """Arithmetic exception thrown from Spark with an error class."""


class UnsupportedOperationException(PySparkException):
    """Unsupported operation exception thrown from Spark with an error class."""


class ArrayIndexOutOfBoundsException(PySparkException):
    """Array index out of bounds exception thrown from Spark with an error class."""


class DateTimeException(PySparkException):
    """Datetime exception thrown from Spark with an error class."""


class NumberFormatException(IllegalArgumentException):
    """Number format exception thrown from Spark with an error class."""


class StreamingQueryException(PySparkException):
    """Exception that stopped a :class:`StreamingQuery`."""


class QueryExecutionException(PySparkException):
    """Failed to execute a query."""


class PythonException(PySparkException):
    """Exceptions thrown from Python workers."""


class SparkRuntimeException(PySparkException):
    """Runtime exception thrown from Spark with an error class."""


class SparkUpgradeException(PySparkException):
    """Exception thrown because of Spark upgrade."""


class UnknownException(PySparkException):
    """None of the above exceptions."""


class PySparkValueError(PySparkException, ValueError):
    """ValueError that supports error classes; catchable as PySparkException or ValueError."""


class PySparkIndexError(PySparkException, IndexError):
    """IndexError that supports error classes; catchable as PySparkException or IndexError."""


class PySparkTypeError(PySparkException, TypeError):
    """TypeError that supports error classes; catchable as PySparkException or TypeError."""


class PySparkAttributeError(PySparkException, AttributeError):
    """AttributeError that supports error classes; catchable as PySparkException or AttributeError."""


class PySparkRuntimeError(PySparkException, RuntimeError):
    """RuntimeError that supports error classes; catchable as PySparkException or RuntimeError."""


class PySparkAssertionError(PySparkException, AssertionError):
    """AssertionError that supports error classes; catchable as PySparkException or AssertionError."""


class PySparkNotImplementedError(PySparkException, NotImplementedError):
    """NotImplementedError that supports error classes; catchable as PySparkException or NotImplementedError."""

View File

@@ -0,0 +1,111 @@
# # noqa: D100
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from .error_classes import ERROR_CLASSES_MAP
class ErrorClassesReader:
    """A reader to load error information from error_classes.py.

    Templates are looked up in ``self.error_info_map`` and may contain
    ``<name>`` placeholders that are replaced by message parameters.
    """

    def __init__(self) -> None:  # noqa: D107
        self.error_info_map = ERROR_CLASSES_MAP

    def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str:
        """Return the completed error message for ``error_class``.

        Parameters
        ----------
        error_class : str
            ``"MAIN"`` or ``"MAIN.SUB"`` key understood by :meth:`get_message_template`.
        message_parameters : dict
            one value per ``<name>`` placeholder of the template; the key set must
            match the placeholders exactly.

        Returns:
        -------
        str
            the template with every ``<name>`` placeholder replaced by its value.

        Raises:
        ------
        AssertionError
            if the supplied parameter names differ from the template's placeholders.
        ValueError
            if the error class (or sub error class) is unknown.
        """
        message_template = self.get_message_template(error_class)
        placeholder_pattern = "<([a-zA-Z0-9_-]+)>"

        # Verify message parameters.
        message_parameters_from_template = re.findall(placeholder_pattern, message_template)
        assert set(message_parameters_from_template) == set(message_parameters), (
            f"Undefined error message parameter for error class: {error_class}. Parameters: {message_parameters}"
        )

        # Replace only the recognised ``<name>`` placeholders. Going through
        # ``str.format`` would choke on templates that contain literal braces or
        # stray angle brackets; a callable replacement also keeps backslashes in
        # the values from being interpreted as regex escapes.
        return re.sub(
            placeholder_pattern,
            lambda match: str(message_parameters[match.group(1)]),
            message_template,
        )

    def get_message_template(self, error_class: str) -> str:
        """Returns the message template for corresponding error class from error_classes.py.

        For example,
        when given `error_class` is "EXAMPLE_ERROR_CLASS",
        and corresponding error class in error_classes.py looks like the below:

        .. code-block:: python

            "EXAMPLE_ERROR_CLASS" : {
              "message" : [
                "Problem <A> because of <B>."
              ]
            }

        In this case, this function returns:
        "Problem <A> because of <B>."

        For sub error class, when given `error_class` is "EXAMPLE_ERROR_CLASS.SUB_ERROR_CLASS",
        and corresponding error class in error_classes.py looks like the below:

        .. code-block:: python

            "EXAMPLE_ERROR_CLASS" : {
              "message" : [
                "Problem <A> because of <B>."
              ],
              "sub_class" : {
                "SUB_ERROR_CLASS" : {
                  "message" : [
                    "Do <C> to fix the problem."
                  ]
                }
              }
            }

        In this case, this function returns:
        "Problem <A> because of <B>. Do <C> to fix the problem."

        Raises:
        ------
        AssertionError
            if ``error_class`` has more than two dot-separated parts.
        ValueError
            if the main error class, or the requested sub error class, is not defined.
        """
        error_classes = error_class.split(".")
        len_error_classes = len(error_classes)
        assert len_error_classes in (1, 2)

        # Generate message template for main error class.
        main_error_class = error_classes[0]
        if main_error_class not in self.error_info_map:
            msg = f"Cannot find main error class '{main_error_class}'"
            raise ValueError(msg)
        main_error_class_info_map = self.error_info_map[main_error_class]

        # Multi-line templates are stored as a list of lines.
        main_message_template = "\n".join(main_error_class_info_map["message"])

        if len_error_classes == 1:
            return main_message_template

        # Generate message template for sub error class if exists.
        sub_error_class = error_classes[1]
        # A main class without any "sub_class" section is reported the same way as
        # an unknown sub class, instead of leaking a bare KeyError('sub_class').
        main_error_class_subclass_info_map = main_error_class_info_map.get("sub_class", {})
        if sub_error_class not in main_error_class_subclass_info_map:
            msg = f"Cannot find sub error class '{sub_error_class}'"
            raise ValueError(msg)
        sub_error_class_info_map = main_error_class_subclass_info_map[sub_error_class]

        sub_message_template = "\n".join(sub_error_class_info_map["message"])
        return main_message_template + " " + sub_message_template

View File

@@ -0,0 +1,18 @@
# ruff: noqa: D100
from typing import Optional
class ContributionsAcceptedError(NotImplementedError):
    """This method is not planned to be implemented, if you would like to implement this method
    or show your interest in this method to other members of the community,
    feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb.
    """  # noqa: D205

    # NOTE: the class docstring above doubles as the default error text, so it is
    # part of the runtime behaviour and must not be reworded casually.

    def __init__(self, message: Optional[str] = None) -> None:  # noqa: D107
        # ``__doc__`` is not inherited by subclasses and is stripped under
        # ``python -OO``, so it may be None; fall back to the base text, then to "".
        doc = self.__class__.__doc__ or ContributionsAcceptedError.__doc__ or ""
        if message:
            # Caller-supplied context goes first, followed by the standard invitation.
            doc = message + "\n" + doc if doc else message
        super().__init__(doc)
__all__ = ["ContributionsAcceptedError"]

View File

@@ -0,0 +1,7 @@
from .catalog import Catalog # noqa: D104
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameWriter
from .session import SparkSession
__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"]

View File

@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
try:
from typing import Literal, Protocol
except ImportError:
from typing_extensions import Literal, Protocol
import datetime
import decimal
from .._typing import PrimitiveType
from . import types
from .column import Column
# A column argument: either a Column object or a string
# (presumably a column name; confirm against callers).
ColumnOrName = Union[Column, str]
# TypeVar bounded by ColumnOrName, for signatures that preserve the caller's argument type.
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)

# Aliases for literal operands accepted next to Column objects.
DecimalLiteral = decimal.Decimal
DateTimeLiteral = Union[datetime.datetime, datetime.date]
LiteralType = PrimitiveType

# A data type given either as a type object or as a string.
AtomicDataTypeOrString = Union[types.AtomicType, str]
DataTypeOrString = Union[types.DataType, str]
OptionalPrimitiveType = Optional[PrimitiveType]

# TypeVar constrained to the listed scalar Python types.
AtomicValue = TypeVar(
    "AtomicValue",
    datetime.datetime,
    datetime.date,
    decimal.Decimal,
    bool,
    str,
    int,
    float,
)

# TypeVar constrained to the three row representations: list, tuple, or types.Row.
RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row)

# NOTE(review): 100 presumably mirrors PySpark's SQL_BATCHED_UDF eval-type constant; confirm.
SQLBatchedUDFType = Literal[100]
class SupportsOpen(Protocol):
    """Structural type for objects providing ``open(partition_id, epoch_id) -> bool``."""

    # NOTE(review): presumably the "open" hook of PySpark's foreach-writer contract; confirm.
    def open(self, partition_id: int, epoch_id: int) -> bool: ...
class SupportsProcess(Protocol):
    """Structural type for objects providing ``process(row) -> None``."""

    # NOTE(review): presumably called once per row by a foreach-style sink; confirm.
    def process(self, row: types.Row) -> None: ...
class SupportsClose(Protocol):
    """Structural type for objects providing ``close(error) -> None``."""

    # NOTE(review): presumably receives the failure that ended processing; confirm.
    def close(self, error: Exception) -> None: ...
class UserDefinedFunctionLike(Protocol):
    """Structural type for user-defined-function objects.

    An object matches when it exposes the attributes below, a ``returnType``
    property, is callable with column arguments (returning a ``Column``), and
    offers ``asNondeterministic()``.
    """

    # presumably the wrapped Python callable; confirm against the UDF implementation
    func: Callable[..., Any]
    # integer evaluation-type code (see SQLBatchedUDFType); exact semantics TODO confirm
    evalType: int
    deterministic: bool

    @property
    def returnType(self) -> types.DataType: ...

    def __call__(self, *args: ColumnOrName) -> Column: ...

    def asNondeterministic(self) -> "UserDefinedFunctionLike": ...

View File

@@ -0,0 +1,79 @@
from typing import NamedTuple, Optional, Union # noqa: D100
from .session import SparkSession
class Database(NamedTuple):
    """Record describing one database, as returned by ``Catalog.listDatabases``."""

    name: str
    # Catalog.listDatabases always fills this with None.
    description: Optional[str]
    # Catalog.listDatabases always fills this with an empty string.
    locationUri: str
class Table(NamedTuple):
    """Record describing one table, as returned by ``Catalog.listTables``."""

    name: str
    database: Optional[str]
    # Catalog.listTables stores the ``sql`` column of ``duckdb_tables()`` here.
    description: Optional[str]
    # Catalog.listTables always fills this with an empty string.
    tableType: str
    isTemporary: bool
class Column(NamedTuple):
    """Record describing one table column, as returned by ``Catalog.listColumns``."""

    name: str
    # Catalog.listColumns always fills this with None.
    description: Optional[str]
    dataType: str
    nullable: bool
    # Catalog.listColumns always reports False for the next two fields.
    isPartition: bool
    isBucket: bool
class Function(NamedTuple):
    """Record describing one function (the return element type of ``Catalog.listFunctions``)."""

    name: str
    description: Optional[str]
    className: str
    isTemporary: bool
class Catalog:
    """Catalog introspection for a session, backed by DuckDB's ``duckdb_*()`` table functions."""

    def __init__(self, session: SparkSession) -> None:  # noqa: D107
        self._session = session

    def listDatabases(self) -> list[Database]:
        """Return one :class:`Database` per row of ``duckdb_databases()``.

        Only the name is populated; ``description`` is ``None`` and ``locationUri`` is ``""``.
        """
        res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()

        def transform_to_database(x: list[str]) -> Database:
            return Database(name=x[0], description=None, locationUri="")

        return [transform_to_database(x) for x in res]

    def listTables(self) -> list[Table]:
        """Return one :class:`Table` per row of ``duckdb_tables()``.

        ``description`` carries the table's ``sql`` column and ``tableType`` is always ``""``.
        """
        res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()

        def transform_to_table(x: list[str]) -> Table:
            return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])

        return [transform_to_table(x) for x in res]

    def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]:
        """Return the columns of ``tableName``, optionally restricted to database ``dbName``.

        Parameters
        ----------
        tableName : str
            name of the table to inspect.
        dbName : str, optional
            when given (and non-empty), only columns of tables in this database are returned.

        Returns:
        -------
        list
            one :class:`Column` per matching row of ``duckdb_columns()``; ``isPartition``
            and ``isBucket`` are always ``False``.
        """
        # Bind the names as prepared-statement parameters instead of splicing them
        # into the SQL text: a name containing a quote must neither break the query
        # nor be able to inject SQL.
        query = "select column_name, data_type, is_nullable from duckdb_columns() where table_name = ?"
        params: list[str] = [tableName]
        if dbName:
            query += " and database_name = ?"
            params.append(dbName)
        res = self._session.conn.sql(query, params=params).fetchall()

        def transform_to_column(x: list[Union[str, bool]]) -> Column:
            return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)

        return [transform_to_column(x) for x in res]

    def listFunctions(self, dbName: Optional[str] = None) -> list[Function]:
        """Not implemented; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    def setCurrentDatabase(self, dbName: str) -> None:
        """Not implemented; always raises :class:`NotImplementedError`."""
        raise NotImplementedError
__all__ = ["Catalog", "Column", "Database", "Function", "Table"]

View File

@@ -0,0 +1,361 @@
from collections.abc import Iterable # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from ..exception import ContributionsAcceptedError
from .types import DataType
if TYPE_CHECKING:
from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType
from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression
from duckdb.sqltypes import DuckDBPyType
__all__ = ["Column"]
def _get_expr(x: Union["Column", str]) -> Expression:
    """Return the expression behind *x*: unwrap a Column, wrap anything else as a constant."""
    if isinstance(x, Column):
        return x.expr
    return ConstantExpression(x)
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
def _(self: "Column") -> "Column":
njc = getattr(self.expr, name)()
return Column(njc)
_.__doc__ = doc
return _
def _unary_op(
name: str,
doc: str = "unary operator",
) -> Callable[["Column"], "Column"]:
"""Create a method for given unary operator."""
def _(self: "Column") -> "Column":
# Call the function identified by 'name' on the internal Expression object
expr = getattr(self.expr, name)()
return Column(expr)
_.__doc__ = doc
return _
def _bin_op(
name: str,
doc: str = "binary operator",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a method for given binary operator."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
jc = _get_expr(other)
njc = getattr(self.expr, name)(jc)
return Column(njc)
_.__doc__ = doc
return _
def _bin_func(
name: str,
doc: str = "binary function",
) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]:
"""Create a function expression for the given binary function."""
def _(
self: "Column",
other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
) -> "Column":
other = _get_expr(other)
func = FunctionExpression(name, self.expr, other)
return Column(func)
_.__doc__ = doc
return _
class Column:
    """A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame
        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName

    Every instance wraps an ``Expression`` stored in ``self.expr``; the operators
    and methods below build a new expression and return it in a new Column.

    .. versionadded:: 1.3.0
    """

    def __init__(self, expr: Expression) -> None:  # noqa: D107
        self.expr = expr

    # arithmetic operators
    def __neg__(self) -> "Column":  # noqa: D105
        return Column(-self.expr)

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op("__and__")
    __or__ = _bin_op("__or__")
    __invert__ = _func_op("__invert__")
    __rand__ = _bin_op("__rand__")
    __ror__ = _bin_op("__ror__")

    __add__ = _bin_op("__add__")

    __sub__ = _bin_op("__sub__")

    __mul__ = _bin_op("__mul__")

    __div__ = _bin_op("__div__")

    __truediv__ = _bin_op("__truediv__")

    __mod__ = _bin_op("__mod__")

    __pow__ = _bin_op("__pow__")

    __radd__ = _bin_op("__radd__")

    __rsub__ = _bin_op("__rsub__")

    __rmul__ = _bin_op("__rmul__")

    __rdiv__ = _bin_op("__rdiv__")

    __rtruediv__ = _bin_op("__rtruediv__")

    __rmod__ = _bin_op("__rmod__")

    __rpow__ = _bin_op("__rpow__")

    def __getitem__(self, k: Any) -> "Column":  # noqa: ANN401
        """An expression that gets an item at position ``ordinal`` out of a list,
        or gets an item by key out of a dict.

        .. versionadded:: 1.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        k
            a literal value, or a slice object without step.

        Returns:
        -------
        :class:`Column`
            Column representing the item got by key out of a dict, or substrings sliced by
            the given slice object.

        Raises:
        ------
        ContributionsAcceptedError
            if ``k`` is a :class:`slice`; slicing is not implemented here.

        Examples:
        --------
        >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
        >>> df.select(df.l[slice(1, 3)], df.d["key"]).show()
        +------------------+------+
        |substring(l, 1, 3)|d[key]|
        +------------------+------+
        |               abc| value|
        +------------------+------+
        """  # noqa: D205
        if isinstance(k, slice):
            raise ContributionsAcceptedError
            # if k.step is not None:
            #    raise ValueError("Using a slice with a step value is not supported")
            # return self.substr(k.start, k.stop)
        else:
            # TODO: this is super hacky  # noqa: TD002, TD003
            expr_str = str(self.expr) + "." + str(k)
            return Column(ColumnExpression(expr_str))

    def __getattr__(self, item: Any) -> "Column":  # noqa: ANN401
        """An expression that gets an item at position ``ordinal`` out of a list,
        or gets an item by key out of a dict.

        Parameters
        ----------
        item
            a literal value.

        Returns:
        -------
        :class:`Column`
            Column representing the item got by key out of a dict.

        Raises:
        ------
        AttributeError
            for dunder names, and for ``expr`` itself when it has not been set.

        Examples:
        --------
        >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"])
        >>> df.select(df.d.key).show()
        +------+
        |d[key]|
        +------+
        | value|
        +------+
        """  # noqa: D205
        if item.startswith("__"):
            msg = "Can not access __ (dunder) method"
            raise AttributeError(msg)
        if item == "expr":
            # __getattr__ only runs when normal lookup failed, so reaching this
            # point means `expr` was never assigned (instance created without
            # __init__). `self[item]` reads `self.expr`, which would re-enter
            # __getattr__ and recurse until RecursionError.
            raise AttributeError(item)
        return self[item]

    def alias(self, alias: str) -> "Column":
        """Return this column's expression under the name ``alias``."""
        return Column(self.expr.alias(alias))

    def when(self, condition: "Column", value: Union["Column", str]) -> "Column":
        """Add a ``condition -> value`` branch to this (case) expression.

        Raises:
        ------
        TypeError
            if ``condition`` is not a :class:`Column`.
        """
        if not isinstance(condition, Column):
            msg = "condition should be a Column"
            raise TypeError(msg)
        v = _get_expr(value)
        expr = self.expr.when(condition.expr, v)
        return Column(expr)

    def otherwise(self, value: Union["Column", str]) -> "Column":
        """Set ``value`` as the fallback of this (case) expression."""
        v = _get_expr(value)
        expr = self.expr.otherwise(v)
        return Column(expr)

    def cast(self, dataType: Union[DataType, str]) -> "Column":
        """Cast the column to ``dataType``, given as a type-name string or a :class:`DataType`."""
        internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type
        return Column(self.expr.cast(internal_type))

    def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column":
        """Return a membership test of this column against the given values.

        The values may be passed as separate arguments or as one ``list``/``set``.
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            # Only one argument supplied, it's a list
            cols = cast("tuple", cols[0])

        cols = cast(
            "tuple",
            [_get_expr(c) for c in cols],
        )
        return Column(self.expr.isin(*cols))

    # comparison operators
    def __eq__(  # type: ignore[override]
        self,
        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
    ) -> "Column":
        """Binary function."""
        return Column(self.expr == (_get_expr(other)))

    def __ne__(  # type: ignore[override]
        self,
        other: object,
    ) -> "Column":
        """Binary function."""
        return Column(self.expr != (_get_expr(other)))

    __lt__ = _bin_op("__lt__")

    __le__ = _bin_op("__le__")

    __ge__ = _bin_op("__ge__")

    __gt__ = _bin_op("__gt__")

    # String interrogation methods

    contains = _bin_func("contains")
    rlike = _bin_func("regexp_matches")
    like = _bin_func("~~")
    ilike = _bin_func("~~*")
    startswith = _bin_func("starts_with")
    endswith = _bin_func("suffix")

    # order
    _asc_doc = """
    Returns a sort expression based on the ascending order of the column.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc()).collect()
    [Row(name='Alice'), Row(name='Tom')]
    """
    _asc_nulls_first_doc = """
    Returns a sort expression based on ascending order of the column, and null values
    appear before non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
    [Row(name=None), Row(name='Alice'), Row(name='Tom')]
    """
    _asc_nulls_last_doc = """
    Returns a sort expression based on ascending order of the column, and null values
    appear after non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
    [Row(name='Alice'), Row(name='Tom'), Row(name=None)]
    """
    _desc_doc = """
    Returns a sort expression based on the descending order of the column.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc()).collect()
    [Row(name='Tom'), Row(name='Alice')]
    """
    _desc_nulls_first_doc = """
    Returns a sort expression based on the descending order of the column, and null values
    appear before non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
    [Row(name=None), Row(name='Tom'), Row(name='Alice')]
    """
    _desc_nulls_last_doc = """
    Returns a sort expression based on the descending order of the column, and null values
    appear after non-null values.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
    >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
    [Row(name='Tom'), Row(name='Alice'), Row(name=None)]
    """

    asc = _unary_op("asc", _asc_doc)
    desc = _unary_op("desc", _desc_doc)
    nulls_first = _unary_op("nulls_first")
    nulls_last = _unary_op("nulls_last")

    # The four combined orderings are plain methods; their documentation lives in
    # the ``_*_doc`` strings above and is attached right after each definition.
    def asc_nulls_first(self) -> "Column":  # noqa: D102
        return self.asc().nulls_first()

    asc_nulls_first.__doc__ = _asc_nulls_first_doc

    def asc_nulls_last(self) -> "Column":  # noqa: D102
        return self.asc().nulls_last()

    asc_nulls_last.__doc__ = _asc_nulls_last_doc

    def desc_nulls_first(self) -> "Column":  # noqa: D102
        return self.desc().nulls_first()

    desc_nulls_first.__doc__ = _desc_nulls_first_doc

    def desc_nulls_last(self) -> "Column":  # noqa: D102
        return self.desc().nulls_last()

    desc_nulls_last.__doc__ = _desc_nulls_last_doc

    def isNull(self) -> "Column":
        """Return an expression that tests whether the column value is null."""
        return Column(self.expr.isnull())

    def isNotNull(self) -> "Column":
        """Return an expression that tests whether the column value is not null."""
        return Column(self.expr.isnotnull())

View File

@@ -0,0 +1,24 @@
from typing import Optional, Union # noqa: D100
from duckdb import DuckDBPyConnection
from duckdb.experimental.spark._globals import _NoValue, _NoValueType
class RuntimeConfig:
    """Placeholder for Spark's runtime configuration interface.

    The connection is stored, but every accessor below is unimplemented and
    raises :class:`NotImplementedError`.
    """

    def __init__(self, connection: DuckDBPyConnection) -> None:  # noqa: D107
        self._connection = connection

    def set(self, key: str, value: str) -> None:
        """Not implemented; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    def isModifiable(self, key: str) -> bool:
        """Not implemented; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    def unset(self, key: str) -> None:
        """Not implemented; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str:
        """Not implemented; always raises :class:`NotImplementedError`."""
        # `_NoValue` is the sentinel default, so an explicit ``default=None`` can be
        # told apart from "no default given" once this is implemented.
        raise NotImplementedError
__all__ = ["RuntimeConfig"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,424 @@
# # noqa: D100
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Callable, Union, overload
from ..exception import ContributionsAcceptedError
from .column import Column
from .dataframe import DataFrame
from .functions import _to_column_expr
from .types import NumericType
# Only import symbols needed for type checking if something is type checking
if TYPE_CHECKING:
from ._typing import ColumnOrName
from .session import SparkSession
__all__ = ["GroupedData", "Grouping"]
def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame:
    """Apply aggregate function *name* to *cols* over the groups of *self*.

    The column names are joined with commas and handed, together with the
    grouping expression and the grouping columns (as projections), to
    ``relation.apply``; the resulting relation is wrapped in a new DataFrame
    bound to the same session.
    """
    grouping = self._grouping
    aggregated_inputs = ",".join(cols)
    group_expression = str(grouping) if grouping else ""
    projected = grouping.get_columns()
    relation = self._df.relation.apply(
        function_name=name,  # aggregate function
        function_aggr=aggregated_inputs,  # inputs to aggregate
        group_expr=group_expression,  # groups
        projected_columns=projected,  # projections
    )
    return DataFrame(relation, self.session)
def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
    """Replace *f* with a method that runs the aggregate function named like *f*.

    The body of *f* is never executed; only its ``__name__`` (used as the
    aggregate function name) and its ``__doc__`` are carried over.
    """

    def _api(self: "GroupedData", *cols: str) -> DataFrame:
        return _api_internal(self, f.__name__, *cols)

    _api.__doc__ = f.__doc__
    _api.__name__ = f.__name__
    return _api
class Grouping:
    """The grouping columns of a group-by, optionally wrapped in ``cube(...)`` or ``rollup(...)``."""

    def __init__(self, *cols: "ColumnOrName", **kwargs) -> None:  # noqa: D107
        self._type = ""
        self._cols = [_to_column_expr(col) for col in cols]
        if "special" in kwargs:
            requested = kwargs["special"]
            # Only these two grouping-set kinds are recognised.
            assert requested in ["cube", "rollup"]
            self._type = requested

    def get_columns(self) -> str:
        """Return the grouping columns rendered as a comma-separated string."""
        return ",".join(str(col) for col in self._cols)

    def __str__(self) -> str:  # noqa: D105
        rendered = self.get_columns()
        if not self._type:
            return rendered
        return f"{self._type}({rendered})"
class GroupedData:
"""A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
""" # noqa: D205
def __init__(self, grouping: Grouping, df: DataFrame) -> None:  # noqa: D107
    # Keep the source frame and reuse its session for every frame produced
    # by the aggregation methods.
    self._df = df
    self.session: SparkSession = df.session
    self._grouping = grouping
def __repr__(self) -> str:  # noqa: D105
    # The representation is simply that of the underlying DataFrame; the
    # grouping itself is not shown.
    return str(self._df)
def count(self) -> DataFrame:
    """Count the records in each group.

    Returns:
    -------
    :class:`DataFrame`
        the grouping columns plus a ``count`` column.

    Examples:
    --------
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
    ... )
    >>> df.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  3|Alice|
    |  5|  Bob|
    | 10|  Bob|
    +---+-----+

    Group-by name, and count each group.

    >>> df.groupBy(df.name).count().sort("name").show()
    +-----+-----+
    | name|count|
    +-----+-----+
    |Alice|    2|
    |  Bob|    2|
    +-----+-----+
    """
    counted = _api_internal(self, "count")
    # Expose the aggregate column under Spark's name instead of "count_star()".
    return counted.withColumnRenamed("count_star()", "count")
@df_varargs_api
def mean(self, *cols: str) -> DataFrame:
    """Computes the average value of each given column for each group.

    :func:`mean` is documented as an alias for :func:`avg`. Note that this stub
    has no body: ``df_varargs_api`` replaces it with a call that applies the
    aggregate function named ``mean`` to ``cols`` as given, so, unlike
    :func:`avg`, no numeric columns are selected automatically when ``cols``
    is empty.

    Parameters
    ----------
    cols : str
        column names to aggregate.
    """
def avg(self, *cols: str) -> DataFrame:
    """Compute the average of the given columns for each group.

    :func:`mean` is an alias for :func:`avg`.

    Parameters
    ----------
    cols : str
        column names. When none are given, every column of the underlying
        frame whose data type is a ``NumericType`` is averaged.

    Examples:
    --------
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
    ...     ["age", "name", "height"],
    ... )
    >>> df.show()
    +---+-----+------+
    |age| name|height|
    +---+-----+------+
    |  2|Alice|    80|
    |  3|Alice|   100|
    |  5|  Bob|   120|
    | 10|  Bob|   140|
    +---+-----+------+

    Group-by name, and calculate the mean of the age in each group.

    >>> df.groupBy("name").avg("age").sort("name").show()
    +-----+--------+
    | name|avg(age)|
    +-----+--------+
    |Alice|     2.5|
    |  Bob|     7.5|
    +-----+--------+

    Calculate the mean of the age and height in all data.

    >>> df.groupBy().avg("age", "height").show()
    +--------+-----------+
    |avg(age)|avg(height)|
    +--------+-----------+
    |     5.0|      110.0|
    +--------+-----------+
    """
    selected = list(cols)
    if not selected:
        # No explicit columns: fall back to the numeric columns of the relation.
        selected = [
            field.name for field in self._df.schema.fields if isinstance(field.dataType, NumericType)
        ]
    return _api_internal(self, "avg", *selected)
@df_varargs_api
def max(self, *cols: str) -> DataFrame:
    """Computes the max value of each given column for each group.

    This stub has no body: ``df_varargs_api`` replaces it with a call that
    applies the aggregate function named ``max`` to ``cols``.

    Parameters
    ----------
    cols : str
        column names to aggregate.

    Examples:
    --------
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
    ...     ["age", "name", "height"],
    ... )
    >>> df.show()
    +---+-----+------+
    |age| name|height|
    +---+-----+------+
    |  2|Alice|    80|
    |  3|Alice|   100|
    |  5|  Bob|   120|
    | 10|  Bob|   140|
    +---+-----+------+

    Group-by name, and calculate the max of the age in each group.

    >>> df.groupBy("name").max("age").sort("name").show()
    +-----+--------+
    | name|max(age)|
    +-----+--------+
    |Alice|       3|
    |  Bob|      10|
    +-----+--------+

    Calculate the max of the age and height in all data.

    >>> df.groupBy().max("age", "height").show()
    +--------+-----------+
    |max(age)|max(height)|
    +--------+-----------+
    |      10|        140|
    +--------+-----------+
    """
@df_varargs_api
def min(self, *cols: str) -> DataFrame:
    """Computes the min value of each given column for each group.

    This stub has no body: ``df_varargs_api`` replaces it with a call that
    applies the aggregate function named ``min`` to ``cols``.

    Parameters
    ----------
    cols : str
        column names to aggregate.
        NOTE(review): upstream PySpark documents that non-numeric columns are
        ignored; here ``cols`` is forwarded unchanged. Confirm the behaviour.

    Examples:
    --------
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
    ...     ["age", "name", "height"],
    ... )
    >>> df.show()
    +---+-----+------+
    |age| name|height|
    +---+-----+------+
    |  2|Alice|    80|
    |  3|Alice|   100|
    |  5|  Bob|   120|
    | 10|  Bob|   140|
    +---+-----+------+

    Group-by name, and calculate the min of the age in each group.

    >>> df.groupBy("name").min("age").sort("name").show()
    +-----+--------+
    | name|min(age)|
    +-----+--------+
    |Alice|       2|
    |  Bob|       5|
    +-----+--------+

    Calculate the min of the age and height in all data.

    >>> df.groupBy().min("age", "height").show()
    +--------+-----------+
    |min(age)|min(height)|
    +--------+-----------+
    |       2|         80|
    +--------+-----------+
    """
@df_varargs_api
def sum(self, *cols: str) -> DataFrame:
    """Computes the sum of each given column for each group.

    This stub has no body: ``df_varargs_api`` replaces it with a call that
    applies the aggregate function named ``sum`` to ``cols``.

    Parameters
    ----------
    cols : str
        column names to aggregate.
        NOTE(review): upstream PySpark documents that non-numeric columns are
        ignored; here ``cols`` is forwarded unchanged. Confirm the behaviour.

    Examples:
    --------
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)],
    ...     ["age", "name", "height"],
    ... )
    >>> df.show()
    +---+-----+------+
    |age| name|height|
    +---+-----+------+
    |  2|Alice|    80|
    |  3|Alice|   100|
    |  5|  Bob|   120|
    | 10|  Bob|   140|
    +---+-----+------+

    Group-by name, and calculate the sum of the age in each group.

    >>> df.groupBy("name").sum("age").sort("name").show()
    +-----+--------+
    | name|sum(age)|
    +-----+--------+
    |Alice|       5|
    |  Bob|      15|
    +-----+--------+

    Calculate the sum of the age and height in all data.

    >>> df.groupBy().sum("age", "height").show()
    +--------+-----------+
    |sum(age)|sum(height)|
    +--------+-----------+
    |      20|        440|
    +--------+-----------+
    """
@overload
def agg(self, *exprs: Column) -> DataFrame: ...
@overload
def agg(self, __exprs: dict[str, str]) -> DataFrame: ...  # noqa: PYI063
def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame:
    """Aggregate every group and return the result as a :class:`DataFrame`.

    The available aggregate functions can be:

    1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

    2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

       .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
           a full shuffle is required. Also, all the data of a group will be loaded into
           memory, so the user should be aware of the potential OOM risk if data is skewed
           and certain groups are too large to fit in memory.

       .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

    If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
    is the column to perform aggregation on, and the value is the aggregate function.
    This implementation does not support the dict form yet and raises
    ``ContributionsAcceptedError`` for it.

    Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

    .. versionadded:: 1.3.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    exprs : dict
        a dict mapping from column name (string) to aggregate functions (string),
        or a list of :class:`Column`.

    Notes:
    -----
    Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
    in a single call to this function.

    Examples:
    --------
    >>> from pyspark.sql import functions as F
    >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
    >>> df = spark.createDataFrame(
    ...     [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
    ... )
    >>> df.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  3|Alice|
    |  5|  Bob|
    | 10|  Bob|
    +---+-----+

    Group-by name, and count each group.

    >>> df.groupBy(df.name)
    GroupedData[grouping...: [name...], value: [age: bigint, name: string], type: GroupBy]

    >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
    +-----+--------+
    | name|count(1)|
    +-----+--------+
    |Alice|       2|
    |  Bob|       2|
    +-----+--------+

    Group-by name, and calculate the minimum age.

    >>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show()
    +-----+--------+
    | name|min(age)|
    +-----+--------+
    |Alice|       2|
    |  Bob|       5|
    +-----+--------+

    Same as above but uses pandas UDF.

    >>> @pandas_udf("int", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
    ... def min_udf(v):
    ...     return v.min()
    >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()  # doctest: +SKIP
    +-----+------------+
    | name|min_udf(age)|
    +-----+------------+
    |Alice|           2|
    |  Bob|           5|
    +-----+------------+
    """
    assert exprs, "exprs should not be empty"
    single_dict_form = len(exprs) == 1 and isinstance(exprs[0], dict)
    if single_dict_form:
        # The {column: function} form has not been implemented
        raise ContributionsAcceptedError
    # Column form: the grouping columns come first, followed by the aggregates
    assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
    projection = [*self._grouping._cols, *[column.expr for column in exprs]]
    grouping_clause = str(self._grouping)
    aggregated = self._df.relation.select(*projection, groups=grouping_clause)
    return DataFrame(aggregated, self.session)
# TODO: add 'pivot' # noqa: TD002, TD003

View File

@@ -0,0 +1,435 @@
from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100
from ..errors import PySparkNotImplementedError, PySparkTypeError
from ..exception import ContributionsAcceptedError
from .types import StructType
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
if TYPE_CHECKING:
from duckdb.experimental.spark.sql.dataframe import DataFrame
from duckdb.experimental.spark.sql.session import SparkSession
class DataFrameWriter:
    """Writes the relation behind a ``DataFrame`` to a table, Parquet file or CSV file."""

    def __init__(self, dataframe: "DataFrame") -> None:
        """Remember the DataFrame whose relation will be written."""
        self.dataframe = dataframe

    def saveAsTable(self, table_name: str) -> None:
        """Create a table named *table_name* from the DataFrame's relation."""
        relation = self.dataframe.relation
        relation.create(table_name)

    def parquet(
        self,
        path: str,
        mode: Optional[str] = None,
        partitionBy: Union[str, list[str], None] = None,
        compression: Optional[str] = None,
    ) -> None:
        """Write the DataFrame to *path* as Parquet.

        Only ``compression`` is forwarded to the relation.

        Raises:
            NotImplementedError: if ``mode`` or ``partitionBy`` is given.
        """
        relation = self.dataframe.relation
        if mode:
            raise NotImplementedError
        if partitionBy:
            raise NotImplementedError

        relation.write_parquet(path, compression=compression)

    def csv(
        self,
        path: str,
        mode: Optional[str] = None,
        compression: Optional[str] = None,
        sep: Optional[str] = None,
        quote: Optional[str] = None,
        escape: Optional[str] = None,
        header: Optional[Union[bool, str]] = None,
        nullValue: Optional[str] = None,
        escapeQuotes: Optional[Union[bool, str]] = None,
        quoteAll: Optional[Union[bool, str]] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
        charToEscapeQuoteEscaping: Optional[str] = None,
        encoding: Optional[str] = None,
        emptyValue: Optional[str] = None,
        lineSep: Optional[str] = None,
    ) -> None:
        """Write the DataFrame to *path* as CSV.

        ``header`` may be a bool or a string; the string is compared case-insensitively
        against ``"true"``, so ``"true"`` and ``"True"`` both enable the header. Any
        other value, including ``None``, disables it.

        Raises:
            NotImplementedError: if ``mode`` is anything other than ``None`` or
                ``"overwrite"``, or if any of ``escapeQuotes``, ``ignoreLeadingWhiteSpace``,
                ``ignoreTrailingWhiteSpace``, ``charToEscapeQuoteEscaping``, ``emptyValue``
                or ``lineSep`` is set to a truthy value.
        """
        if mode not in (None, "overwrite"):
            raise NotImplementedError
        if escapeQuotes:
            raise NotImplementedError
        if ignoreLeadingWhiteSpace:
            raise NotImplementedError
        if ignoreTrailingWhiteSpace:
            raise NotImplementedError
        if charToEscapeQuoteEscaping:
            raise NotImplementedError
        if emptyValue:
            raise NotImplementedError
        if lineSep:
            raise NotImplementedError

        # String flags are usually spelled lowercase ("true"), so compare
        # case-insensitively instead of only matching the exact string "True".
        if isinstance(header, bool):
            write_header = header
        else:
            write_header = isinstance(header, str) and header.lower() == "true"

        relation = self.dataframe.relation
        relation.write_csv(
            path,
            sep=sep,
            na_rep=nullValue,
            quotechar=quote,
            compression=compression,
            escapechar=escape,
            header=write_header,
            encoding=encoding,
            # NOTE(review): quoteAll is forwarded unchanged as `quoting`; confirm that
            # write_csv accepts the bool/str values callers pass here.
            quoting=quoteAll,
            date_format=dateFormat,
            timestamp_format=timestampFormat,
        )
class DataFrameReader:
    """Reads CSV, JSON and Parquet data through the session's connection into DataFrames."""

    def __init__(self, session: "SparkSession") -> None:
        """Remember the session whose connection performs the reads."""
        self.session = session

    def load(
        self,
        path: Optional[Union[str, list[str]]] = None,
        format: Optional[str] = None,
        schema: Optional[Union[StructType, str]] = None,
        **options: OptionalPrimitiveType,
    ) -> "DataFrame":
        """Load *path* using the reader selected by *format*.

        ``format`` is matched case-insensitively against ``csv``/``tsv``, ``json`` and
        ``parquet``. Without a format, *path* is queried with ``select * from <path>``.
        When a ``StructType`` schema is given, the result is cast to its types and
        renamed to its field names.

        Raises:
            TypeError: if *path* is not a string.
            ContributionsAcceptedError: if any option is passed, the format is not one
                of the supported ones, or *schema* is not a ``StructType``.
        """
        from duckdb.experimental.spark.sql.dataframe import DataFrame

        if not isinstance(path, str):
            raise TypeError
        if options:
            raise ContributionsAcceptedError

        if format:
            format = format.lower()
            if format in ("csv", "tsv"):
                rel = self.session.conn.read_csv(path)
            elif format == "json":
                rel = self.session.conn.read_json(path)
            elif format == "parquet":
                rel = self.session.conn.read_parquet(path)
            else:
                raise ContributionsAcceptedError
        else:
            # NOTE(review): `path` is interpolated into the SQL text unescaped. Confirm
            # that callers never pass untrusted input here, or quote it.
            rel = self.session.conn.sql(f"select * from {path}")
        df = DataFrame(rel, self.session)
        if schema:
            if not isinstance(schema, StructType):
                raise ContributionsAcceptedError
            schema = cast("StructType", schema)
            types, names = schema.extract_types_and_names()
            # Unpack both lists, matching how `csv` below calls `toDF(*names)`.
            df = df._cast_types(*types)
            df = df.toDF(*names)
        return df

    def csv(
        self,
        path: Union[str, list[str]],
        schema: Optional[Union[StructType, str]] = None,
        sep: Optional[str] = None,
        encoding: Optional[str] = None,
        quote: Optional[str] = None,
        escape: Optional[str] = None,
        comment: Optional[str] = None,
        header: Optional[Union[bool, str]] = None,
        inferSchema: Optional[Union[bool, str]] = None,
        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
        nullValue: Optional[str] = None,
        nanValue: Optional[str] = None,
        positiveInf: Optional[str] = None,
        negativeInf: Optional[str] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        maxColumns: Optional[Union[int, str]] = None,
        maxCharsPerColumn: Optional[Union[int, str]] = None,
        maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
        mode: Optional[str] = None,
        columnNameOfCorruptRecord: Optional[str] = None,
        multiLine: Optional[Union[bool, str]] = None,
        charToEscapeQuoteEscaping: Optional[str] = None,
        samplingRatio: Optional[Union[float, str]] = None,
        enforceSchema: Optional[Union[bool, str]] = None,
        emptyValue: Optional[str] = None,
        locale: Optional[str] = None,
        lineSep: Optional[str] = None,
        pathGlobFilter: Optional[Union[bool, str]] = None,
        recursiveFileLookup: Optional[Union[bool, str]] = None,
        modifiedBefore: Optional[Union[bool, str]] = None,
        modifiedAfter: Optional[Union[bool, str]] = None,
        unescapedQuoteHandling: Optional[str] = None,
    ) -> "DataFrame":
        """Read a single CSV file into a DataFrame.

        Forwarded to ``read_csv``: ``header``, ``sep``, ``nullValue``, ``quote``,
        ``escape``, ``encoding``, ``dateFormat``, ``timestampFormat`` and the column
        types of a ``StructType`` schema. A string ``header`` is compared
        case-insensitively against ``"true"``. When a schema is given, the resulting
        columns are renamed to its field names.

        Raises:
            NotImplementedError: if *path* is not a string, or ``lineSep`` is set.
            ContributionsAcceptedError: if *schema* is not a ``StructType``, or any
                other option not listed above is set to a truthy value.
        """
        if not isinstance(path, str):
            raise NotImplementedError
        if schema and not isinstance(schema, StructType):
            raise ContributionsAcceptedError

        # Options that are accepted for signature compatibility but not implemented.
        # They are checked in declaration order, by truthiness.
        unsupported_options = {
            "comment": comment,
            "inferSchema": inferSchema,
            "ignoreLeadingWhiteSpace": ignoreLeadingWhiteSpace,
            "ignoreTrailingWhiteSpace": ignoreTrailingWhiteSpace,
            "nanValue": nanValue,
            "positiveInf": positiveInf,
            "negativeInf": negativeInf,
            "maxColumns": maxColumns,
            "maxCharsPerColumn": maxCharsPerColumn,
            "maxMalformedLogPerPartition": maxMalformedLogPerPartition,
            "mode": mode,
            "columnNameOfCorruptRecord": columnNameOfCorruptRecord,
            "multiLine": multiLine,
            "charToEscapeQuoteEscaping": charToEscapeQuoteEscaping,
            "samplingRatio": samplingRatio,
            "enforceSchema": enforceSchema,
            "emptyValue": emptyValue,
            "locale": locale,
            "pathGlobFilter": pathGlobFilter,
            "recursiveFileLookup": recursiveFileLookup,
            "modifiedBefore": modifiedBefore,
            "modifiedAfter": modifiedAfter,
            "unescapedQuoteHandling": unescapedQuoteHandling,
        }
        for option_name, option_value in unsupported_options.items():
            if option_value:
                msg = f"The '{option_name}' option is not supported"
                raise ContributionsAcceptedError(msg)
        if lineSep:
            # We have support for custom newline, just needs to be ported to 'read_csv'
            raise NotImplementedError

        dtype = None
        names = None
        if schema:
            schema = cast("StructType", schema)
            dtype, names = schema.extract_types_and_names()

        # String flags are usually spelled lowercase ("true"), so compare
        # case-insensitively instead of only matching the exact string "True".
        if isinstance(header, bool):
            has_header = header
        else:
            has_header = isinstance(header, str) and header.lower() == "true"

        rel = self.session.conn.read_csv(
            path,
            header=has_header,
            sep=sep,
            dtype=dtype,
            na_values=nullValue,
            quotechar=quote,
            escapechar=escape,
            encoding=encoding,
            date_format=dateFormat,
            timestamp_format=timestampFormat,
        )
        from ..sql.dataframe import DataFrame

        df = DataFrame(rel, self.session)
        if names:
            df = df.toDF(*names)
        return df

    def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame":
        """Read a single Parquet file into a DataFrame.

        Raises:
            NotImplementedError: unless exactly one path is given.
            ContributionsAcceptedError: if any option is passed.
        """
        if len(paths) != 1:
            msg = "Only single paths are supported for now"
            raise NotImplementedError(msg)
        if options:
            msg = "Options are not supported"
            raise ContributionsAcceptedError(msg)
        rel = self.session.conn.read_parquet(paths[0])
        from ..sql.dataframe import DataFrame

        return DataFrame(rel, self.session)

    def json(
        self,
        path: Union[str, list[str]],
        schema: Optional[Union[StructType, str]] = None,
        primitivesAsString: Optional[Union[bool, str]] = None,
        prefersDecimal: Optional[Union[bool, str]] = None,
        allowComments: Optional[Union[bool, str]] = None,
        allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
        allowSingleQuotes: Optional[Union[bool, str]] = None,
        allowNumericLeadingZero: Optional[Union[bool, str]] = None,
        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
        mode: Optional[str] = None,
        columnNameOfCorruptRecord: Optional[str] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        multiLine: Optional[Union[bool, str]] = None,
        allowUnquotedControlChars: Optional[Union[bool, str]] = None,
        lineSep: Optional[str] = None,
        samplingRatio: Optional[Union[float, str]] = None,
        dropFieldIfAllNull: Optional[Union[bool, str]] = None,
        encoding: Optional[str] = None,
        locale: Optional[str] = None,
        pathGlobFilter: Optional[Union[bool, str]] = None,
        recursiveFileLookup: Optional[Union[bool, str]] = None,
        modifiedBefore: Optional[Union[bool, str]] = None,
        modifiedAfter: Optional[Union[bool, str]] = None,
        allowNonNumericNumbers: Optional[Union[bool, str]] = None,
    ) -> "DataFrame":
        """Loads JSON files and returns the results as a :class:`DataFrame`.

        `JSON Lines <http://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
        For JSON (one record per file), set the ``multiLine`` parameter to ``true``.

        If the ``schema`` parameter is not specified, this function goes
        through the input once to determine the input schema.

        .. versionadded:: 1.4.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        path : str, list or :class:`RDD`
            string represents path to the JSON dataset, or a list of paths,
            or RDD of Strings storing JSON objects.
        schema : :class:`pyspark.sql.types.StructType` or str, optional
            an optional :class:`pyspark.sql.types.StructType` for the input schema or
            a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).

        Other Parameters
        ----------------
        Extra options
            For the extra options, refer to
            `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
            for the version you use.

            .. # noqa

        Notes:
        -----
        In this implementation only ``path`` is honored: it must be a string or a list
        holding exactly one path. Passing any other parameter (including ``schema`` and
        ``multiLine``) as anything but ``None`` raises ``ContributionsAcceptedError``.

        Raises:
        ------
        ContributionsAcceptedError
            if any parameter other than ``path`` is not ``None``.
        PySparkNotImplementedError
            if ``path`` is a list that does not contain exactly one entry.
        PySparkTypeError
            if ``path`` is neither a string nor a list.

        Examples:
        --------
        Write a DataFrame into a JSON file and read it back.

        >>> import tempfile
        >>> with tempfile.TemporaryDirectory() as d:
        ...     # Write a DataFrame into a JSON file
        ...     spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
        ...         "overwrite"
        ...     ).format("json").save(d)
        ...
        ...     # Read the JSON file as a DataFrame.
        ...     spark.read.json(d).show()
        +---+------------+
        |age|        name|
        +---+------------+
        |100|Hyukjin Kwon|
        +---+------------+
        """
        # Options that are accepted for signature compatibility but not implemented.
        # They are checked in declaration order; anything other than None is rejected.
        unsupported_options = {
            "schema": schema,
            "primitivesAsString": primitivesAsString,
            "prefersDecimal": prefersDecimal,
            "allowComments": allowComments,
            "allowUnquotedFieldNames": allowUnquotedFieldNames,
            "allowSingleQuotes": allowSingleQuotes,
            "allowNumericLeadingZero": allowNumericLeadingZero,
            "allowBackslashEscapingAnyCharacter": allowBackslashEscapingAnyCharacter,
            "mode": mode,
            "columnNameOfCorruptRecord": columnNameOfCorruptRecord,
            "dateFormat": dateFormat,
            "timestampFormat": timestampFormat,
            "multiLine": multiLine,
            "allowUnquotedControlChars": allowUnquotedControlChars,
            "lineSep": lineSep,
            "samplingRatio": samplingRatio,
            "dropFieldIfAllNull": dropFieldIfAllNull,
            "encoding": encoding,
            "locale": locale,
            "pathGlobFilter": pathGlobFilter,
            "recursiveFileLookup": recursiveFileLookup,
            "modifiedBefore": modifiedBefore,
            "modifiedAfter": modifiedAfter,
            "allowNonNumericNumbers": allowNonNumericNumbers,
        }
        for option_name, option_value in unsupported_options.items():
            if option_value is not None:
                msg = f"The '{option_name}' option is not supported"
                raise ContributionsAcceptedError(msg)

        if isinstance(path, str):
            path = [path]
        if not isinstance(path, list):
            raise PySparkTypeError(
                error_class="NOT_STR_OR_LIST_OF_RDD",
                message_parameters={
                    "arg_name": "path",
                    "arg_type": type(path).__name__,
                },
            )
        if len(path) != 1:
            raise PySparkNotImplementedError(message="Only a single path is supported for now")

        rel = self.session.conn.read_json(path[0])
        from .dataframe import DataFrame

        return DataFrame(rel, self.session)
__all__ = ["DataFrameReader", "DataFrameWriter"]

View File

@@ -0,0 +1,297 @@
import uuid # noqa: D100
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union
import duckdb
if TYPE_CHECKING:
from pandas.core.frame import DataFrame as PandasDataFrame
from .catalog import Catalog
from ..conf import SparkConf
from ..context import SparkContext
from ..errors import PySparkTypeError
from ..exception import ContributionsAcceptedError
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameReader
from .streaming import DataStreamReader
from .types import StructType
from .udf import UDFRegistration
# In spark:
# SparkSession holds a SparkContext
# SparkContext gets created from SparkConf
# At this level the check is made to determine whether the instance already exists and just needs
# to be retrieved or it needs to be created.
# For us this is done inside of `duckdb.connect`, based on the passed in path + configuration
# SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class
# data is a List of rows
# every value in each row needs to be turned into a Value
def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[list[duckdb.Value]]:
    """Wrap every cell of *data* in a ``duckdb.Value`` typed by the matching schema field.

    Parameters:
        data: iterable of rows; each row is an iterable of raw cell values.
        schema: schema whose fields provide, in order, the ``dataType`` of each column.

    Returns:
        One list per row, holding a ``Value`` for each cell. Cells are paired with the
        schema fields positionally via ``zip``, so surplus cells or fields are dropped.
    """
    from duckdb import Value

    # The field types depend only on the schema, so collect them once instead of per row.
    field_types = [field.dataType for field in schema]
    return [
        [Value(cell, field_type.duckdb_type) for cell, field_type in zip(row, field_types)]
        for row in data
    ]
class SparkSession:
    """PySpark-style session that produces DataFrames from a DuckDB connection.

    The session uses the connection of the ``SparkContext`` it is created from and wraps
    the relations produced on that connection in ``DataFrame`` objects.
    """

    def __init__(self, context: SparkContext) -> None:
        """Bind the session to *context* and to the connection it exposes."""
        self.conn = context.connection
        self._context = context
        self._conf = RuntimeConfig(self.conn)

    @staticmethod
    def _is_pandas_dataframe(data: object) -> bool:
        """Return True when pandas is importable and *data* is a ``pandas.DataFrame``."""
        try:
            import pandas
        except ImportError:
            return False
        return isinstance(data, pandas.DataFrame)

    def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame:
        """Build a DataFrame from a pandas DataFrame or from an iterable of rows.

        A pandas DataFrame is registered on the connection under a unique name and
        selected from. Any other input is materialized into a list of rows and loaded
        through a parameterized ``VALUES`` query.

        Raises:
            PySparkTypeError: if a row's length differs from the first row's length.
            ValueError: if there are no rows, so no column count can be derived.
        """
        if self._is_pandas_dataframe(data):
            unique_name = f"pyspark_pandas_df_{uuid.uuid1()}"
            self.conn.register(unique_name, data)
            return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)

        if not isinstance(data, list):
            data = list(data)
        if not data:
            # Without a first row there is nothing to derive the column count from;
            # fail with a clear message instead of an IndexError on data[0].
            msg = "Cannot create a DataFrame from an empty dataset"
            raise ValueError(msg)

        # Every row must have as many values as the first one.
        expected_length = len(data[0])
        for i, item in enumerate(data[1:]):
            actual_length = len(item)
            if actual_length != expected_length:
                raise PySparkTypeError(
                    error_class="LENGTH_SHOULD_BE_THE_SAME",
                    message_parameters={
                        "arg1": f"data{i}",
                        "arg2": f"data{i + 1}",
                        "arg1_length": str(expected_length),
                        "arg2_length": str(actual_length),
                    },
                )

        # One "($a, $b, ...)" group per row, with placeholders numbered from $1 across rows.
        row_groups = []
        for row_index, row in enumerate(data):
            first_param = 1 + (row_index * expected_length)
            placeholders = ", ".join(f"${first_param + offset}" for offset in range(len(row)))
            row_groups.append("(" + placeholders + ")")
        values_list = ", ".join(row_groups)
        query = f"""
            select * from (values {values_list})
        """

        # Flatten the rows in the same order the placeholders were numbered.
        parameters = [value for row in data for value in row]

        rel = self.conn.sql(query, params=parameters)
        return DataFrame(rel, self)

    def _createDataFrameFromPandas(
        self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None]
    ) -> DataFrame:
        """Create a DataFrame from pandas *data*, then apply the optional casts and names."""
        df = self._create_dataframe(data)

        # Cast to types
        if types:
            df = df._cast_types(*types)
        # Alias to names
        if names:
            df = df.toDF(*names)
        return df

    def createDataFrame(
        self,
        data: Union["PandasDataFrame", Iterable[Any]],
        schema: Optional[Union[StructType, list[str]]] = None,
        samplingRatio: Optional[float] = None,
        verifySchema: bool = True,
    ) -> DataFrame:
        """Create a DataFrame from a pandas DataFrame or an iterable of rows.

        ``schema`` is either a ``StructType`` (column types and names) or a list of
        column names. When *data* is falsy and names are known, an empty DataFrame with
        those columns is returned.

        Raises:
            NotImplementedError: if ``samplingRatio`` is set or ``verifySchema`` is false.
            PySparkTypeError: if *data* is already a ``DataFrame``, or the rows have
                differing lengths.
            ValueError: if *data* holds no rows and no column names are known.
        """
        if samplingRatio:
            raise NotImplementedError
        if not verifySchema:
            raise NotImplementedError
        types = None
        names = None

        if isinstance(data, DataFrame):
            raise PySparkTypeError(
                error_class="SHOULD_NOT_DATAFRAME",
                message_parameters={"arg_name": "data"},
            )

        if schema:
            if isinstance(schema, StructType):
                types, names = schema.extract_types_and_names()
            else:
                names = schema

        # Falsey check on pandas dataframe is not defined, so first check if it's not a pandas dataframe
        # Then check if 'data' is None or []
        if self._is_pandas_dataframe(data):
            return self._createDataFrameFromPandas(data, types, names)

        # Finally check if a schema was provided
        is_empty = False
        if not data and names:
            # Create NULLs for every type in our dataframe
            is_empty = True
            data = [tuple(None for _ in names)]

        if schema and isinstance(schema, StructType):
            # Transform the data into Values to combine the data+schema
            data = _combine_data_and_schema(data, schema)

        df = self._create_dataframe(data)
        if is_empty:
            # Add impossible where clause so the placeholder NULL row is dropped
            df = DataFrame(df.relation.filter("1=0"), self)

        # Cast to types
        if types:
            df = df._cast_types(*types)
        # Alias to names
        if names:
            df = df.toDF(*names)
        return df

    def newSession(self) -> "SparkSession":
        """Return a new session that shares this session's context."""
        return SparkSession(self._context)

    def range(
        self,
        start: int,
        end: Optional[int] = None,
        step: int = 1,
        numPartitions: Optional[int] = None,
    ) -> "DataFrame":
        """Return a DataFrame over the ``range`` table function.

        With a single argument, that argument is the end and the range starts at 0.

        Raises:
            ContributionsAcceptedError: if ``numPartitions`` is set.
        """
        if numPartitions:
            raise ContributionsAcceptedError

        if end is None:
            end = start
            start = 0

        return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self)

    def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame:  # noqa: ANN401
        """Run *sqlQuery* on the connection and wrap the result in a DataFrame.

        Raises:
            NotImplementedError: if any keyword argument is passed.
        """
        if kwargs:
            raise NotImplementedError
        relation = self.conn.sql(sqlQuery)
        return DataFrame(relation, self)

    def stop(self) -> None:
        """Stop the underlying context."""
        self._context.stop()

    def table(self, tableName: str) -> DataFrame:
        """Return the table named *tableName* as a DataFrame."""
        relation = self.conn.table(tableName)
        return DataFrame(relation, self)

    def getActiveSession(self) -> "SparkSession":
        """Return this session."""
        return self

    @property
    def catalog(self) -> "Catalog":
        """Catalog for this session, created on first access and cached afterwards."""
        if not hasattr(self, "_catalog"):
            from duckdb.experimental.spark.sql.catalog import Catalog

            self._catalog = Catalog(self)
        return self._catalog

    @property
    def conf(self) -> RuntimeConfig:
        """Runtime configuration created for this session's connection."""
        return self._conf

    @property
    def read(self) -> DataFrameReader:
        """A new ``DataFrameReader`` bound to this session."""
        return DataFrameReader(self)

    @property
    def readStream(self) -> DataStreamReader:
        """A new ``DataStreamReader`` bound to this session."""
        return DataStreamReader(self)

    @property
    def sparkContext(self) -> SparkContext:
        """The context this session was created from."""
        return self._context

    @property
    def streams(self) -> NoReturn:
        """Not implemented; always raises ``ContributionsAcceptedError``."""
        raise ContributionsAcceptedError

    @property
    def udf(self) -> UDFRegistration:
        """A new ``UDFRegistration`` bound to this session."""
        return UDFRegistration(self)

    @property
    def version(self) -> str:
        """Version string reported by the session (a fixed value)."""
        return "1.0.0"

    class Builder:
        """Builder whose configuration methods are accepted but ignored."""

        def __init__(self) -> None:
            """Create a builder; it holds no state."""

        def master(self, name: str) -> "SparkSession.Builder":
            """Accept and ignore the master name; returns the builder for chaining."""
            # no-op
            return self

        def appName(self, name: str) -> "SparkSession.Builder":
            """Accept and ignore the application name; returns the builder for chaining."""
            # no-op
            return self

        def remote(self, url: str) -> "SparkSession.Builder":
            """Accept and ignore the remote URL; returns the builder for chaining."""
            # no-op
            return self

        def getOrCreate(self) -> "SparkSession":
            """Create a session on a fresh ``SparkContext``."""
            context = SparkContext("__ignored__")
            return SparkSession(context)

        def config(
            self,
            key: Optional[str] = None,
            value: Optional[Any] = None,  # noqa: ANN401
            conf: Optional[SparkConf] = None,
        ) -> "SparkSession.Builder":
            """Accept and ignore a configuration entry; returns the builder for chaining."""
            return self

        def enableHiveSupport(self) -> "SparkSession.Builder":
            """Accept and ignore the request; returns the builder for chaining."""
            # no-op
            return self

    builder = Builder()
__all__ = ["SparkSession"]

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Optional, Union # noqa: D100
from .types import StructType
if TYPE_CHECKING:
from .dataframe import DataFrame
from .session import SparkSession
PrimitiveType = Union[bool, float, int, str]
OptionalPrimitiveType = Optional[PrimitiveType]
class DataStreamWriter:
    """Placeholder for streaming writes of a DataFrame; nothing is implemented yet."""

    def __init__(self, dataframe: "DataFrame") -> None:
        """Keep a reference to the DataFrame that would be written."""
        self.dataframe = dataframe

    def toTable(self, table_name: str) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        # Should we register the dataframe or create a table from the contents?
        raise NotImplementedError
class DataStreamReader:
    """Placeholder for streaming reads on a session; nothing is implemented yet."""

    def __init__(self, session: "SparkSession") -> None:
        """Keep a reference to the session the reader belongs to."""
        self.session = session

    def load(
        self,
        path: Optional[str] = None,
        format: Optional[str] = None,
        schema: Union[StructType, str, None] = None,
        **options: OptionalPrimitiveType,
    ) -> "DataFrame":
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError
__all__ = ["DataStreamReader", "DataStreamWriter"]

View File

@@ -0,0 +1,113 @@
from typing import cast # noqa: D100
from duckdb.sqltypes import DuckDBPyType
from ..exception import ContributionsAcceptedError
from .types import (
ArrayType,
BinaryType,
BitstringType,
BooleanType,
ByteType,
DataType,
DateType,
DayTimeIntervalType,
DecimalType,
DoubleType,
FloatType,
HugeIntegerType,
IntegerType,
LongType,
MapType,
ShortType,
StringType,
StructField,
StructType,
TimeNTZType,
TimestampMilisecondNTZType,
TimestampNanosecondNTZType,
TimestampNTZType,
TimestampSecondNTZType,
TimestampType,
TimeType,
UnsignedByteType,
UnsignedHugeIntegerType,
UnsignedIntegerType,
UnsignedLongType,
UnsignedShortType,
UUIDType,
)
# Maps a DuckDB type id to the Spark type class used to represent it.
# Parameterized and nested types (decimal, list, struct, map) are listed here too,
# but `convert_type` handles them before it consults this table.
_sqltype_to_spark_class = {
    "boolean": BooleanType,
    "utinyint": UnsignedByteType,
    "tinyint": ByteType,
    "usmallint": UnsignedShortType,
    "smallint": ShortType,
    "uinteger": UnsignedIntegerType,
    "integer": IntegerType,
    "ubigint": UnsignedLongType,
    "bigint": LongType,
    "hugeint": HugeIntegerType,
    "uhugeint": UnsignedHugeIntegerType,
    "varchar": StringType,
    "blob": BinaryType,
    "bit": BitstringType,
    "uuid": UUIDType,
    "date": DateType,
    "time": TimeNTZType,
    "time with time zone": TimeType,
    "timestamp": TimestampNTZType,
    "timestamp with time zone": TimestampType,
    # "_ms" is the millisecond precision and "_ns" the nanosecond one; these two
    # entries used to point at each other's classes.
    "timestamp_ms": TimestampMilisecondNTZType,
    "timestamp_ns": TimestampNanosecondNTZType,
    "timestamp_s": TimestampSecondNTZType,
    "interval": DayTimeIntervalType,
    "list": ArrayType,
    "struct": StructType,
    "map": MapType,
    # union
    # enum
    # null (???)
    "float": FloatType,
    "double": DoubleType,
    "decimal": DecimalType,
}
def convert_nested_type(dtype: DuckDBPyType) -> DataType:
    """Convert a nested DuckDB type (list, array, struct or map) to its Spark type.

    Child types are converted recursively through ``convert_type``.

    Raises:
        ContributionsAcceptedError: for union types, which have no Spark counterpart.
        NotImplementedError: for any other type id.
    """
    type_id = dtype.id
    if type_id in ("list", "array"):
        # The element type is the second item of the first child entry
        element_type = dtype.children[0][1]
        return ArrayType(convert_type(element_type))
    if type_id == "union":
        msg = (
            "Union types are not supported in the PySpark interface. "
            "DuckDB union types cannot be directly mapped to PySpark types."
        )
        raise ContributionsAcceptedError(msg)
    if type_id == "struct":
        members: list[tuple[str, DuckDBPyType]] = dtype.children
        return StructType([StructField(member[0], convert_type(member[1])) for member in members])
    if type_id == "map":
        key_type = convert_type(dtype.key)
        value_type = convert_type(dtype.value)
        return MapType(key_type, value_type)
    raise NotImplementedError
def convert_type(dtype: DuckDBPyType) -> DataType:
    """Convert a DuckDB type to the corresponding Spark ``DataType`` instance.

    Nested types (and unions, which are rejected there with an explanatory error) are
    delegated to ``convert_nested_type``. Decimals carry over their precision and
    scale. Every other type id is looked up in ``_sqltype_to_spark_class``.

    Raises:
        KeyError: if the type id has no entry in ``_sqltype_to_spark_class``.
    """
    type_id = dtype.id
    # "union" is routed to convert_nested_type as well, so callers get its descriptive
    # "not supported" error rather than a bare KeyError from the table lookup below.
    if type_id in ("list", "struct", "map", "array", "union"):
        return convert_nested_type(dtype)
    if type_id == "decimal":
        children: list[tuple[str, DuckDBPyType]] = dtype.children
        # The first child entry holds the precision, the second the scale
        precision = cast("int", children[0][1])
        scale = cast("int", children[1][1])
        return DecimalType(precision, scale)
    spark_type = _sqltype_to_spark_class[type_id]
    return spark_type()
def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType:
    """Build a Spark ``StructType`` from parallel lists of column names and DuckDB types.

    All types are converted first; names and converted types are then paired
    positionally, stopping at the shorter of the two lists.
    """
    spark_types = [convert_type(duckdb_type) for duckdb_type in types]
    fields = []
    for column_name, spark_type in zip(names, spark_types):
        fields.append(StructField(column_name, spark_type))
    return StructType(fields)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from .types import DataType
if TYPE_CHECKING:
from .session import SparkSession
DataTypeOrString = Union[DataType, str]
UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike")
class UDFRegistration:
    """Registers Python functions as SQL functions on a session's connection."""

    def __init__(self, sparkSession: "SparkSession") -> None:
        """Keep the session whose connection receives the registered functions."""
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
        returnType: Optional["DataTypeOrString"] = None,
    ) -> "UserDefinedFunctionLike":
        """Register *f* under *name* via the connection's ``create_function``.

        ``returnType`` is passed on as the ``return_type`` keyword.
        """
        # NOTE(review): the signature promises a UDF-like return value, but nothing is
        # returned here, so callers currently receive None.
        connection = self.sparkSession.conn
        connection.create_function(name, f, return_type=returnType)

    def registerJavaFunction(
        self,
        name: str,
        javaClassName: str,
        returnType: Optional["DataTypeOrString"] = None,
    ) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError
__all__ = ["UDFRegistration"]

View File

@@ -0,0 +1,33 @@
"""In-memory filesystem to store ephemeral dependencies.
Warning: Not for external use. May change at any moment. Likely to be made internal.
"""
from __future__ import annotations
import io
import typing
from fsspec import AbstractFileSystem
from fsspec.implementations.memory import MemoryFile, MemoryFileSystem
from .bytes_io_wrapper import BytesIOWrapper
class ModifiedMemoryFileSystem(MemoryFileSystem):
    """Memory-backed filesystem registered under its own DuckDB-specific protocol."""

    protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",)

    # defer to the original implementation that doesn't hardcode the protocol
    _strip_protocol: typing.Callable[[str], str] = classmethod(AbstractFileSystem._strip_protocol.__func__)  # type: ignore[assignment]

    def add_file(self, obj: io.IOBase | BytesIOWrapper | object, path: str) -> None:
        """Read *obj* to the end and store its content under *path*.

        *obj* must provide ``read`` and ``seek``. Text streams are wrapped in a
        ``BytesIOWrapper`` first so that bytes are stored.

        Raises:
            TypeError: if *obj* lacks ``read`` or ``seek``.
        """
        if not hasattr(obj, "read") or not hasattr(obj, "seek"):
            msg = "Can not read from a non file-like object"
            raise TypeError(msg)
        # Wrap text streams so that we can return a bytes object from 'read'
        source = BytesIOWrapper(obj) if isinstance(obj, io.TextIOBase) else obj
        store_key = self._strip_protocol(path)
        memory_file = MemoryFile(self, store_key, source.read())
        self.store[store_key] = memory_file

View File

@@ -0,0 +1,3 @@
from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104
__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"]

View File

@@ -0,0 +1,13 @@
"""DuckDB function constants and types. DEPRECATED: please use `duckdb.func` instead."""
import warnings
from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType
__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"]
# Runs when this module is imported. stacklevel=2 is meant to attribute the warning to
# the code importing `duckdb.functional` rather than to this file.
warnings.warn(
    "`duckdb.functional` is deprecated and will be removed in a future version. Please use `duckdb.func` instead.",
    DeprecationWarning,
    stacklevel=2,
)

View File

@@ -0,0 +1,285 @@
from __future__ import annotations # noqa: D100
import contextlib
import datetime
import json
import typing
from decimal import Decimal
import polars as pl
from polars.io.plugins import register_io_source
import duckdb
if typing.TYPE_CHECKING:
from collections.abc import Iterator
import typing_extensions
_ExpressionTree: typing_extensions.TypeAlias = typing.Dict[str, typing.Union[str, int, "_ExpressionTree", typing.Any]] # noqa: UP006
def _predicate_to_expression(predicate: pl.Expr) -> duckdb.Expression | None:
    """Translate a Polars predicate into a DuckDB SQL expression, if possible.

    The predicate is serialized to JSON, the resulting tree is rendered to SQL text and
    that text is wrapped in a ``duckdb.SQLExpression``.

    Parameters:
        predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5)

    Returns:
        SQLExpression: The DuckDB expression built from the rendered SQL.
        None: If rendering the tree or building the expression fails.
    """
    # Serialization happens outside the guarded section, so its errors propagate
    serialized = predicate.meta.serialize(format="json")
    tree = json.loads(serialized)
    try:
        return duckdb.SQLExpression(_pl_tree_to_sql(tree))
    except Exception:
        # Best effort: a predicate that cannot be translated yields None
        return None
def _pl_operation_to_sql(op: str) -> str:
"""Map Polars binary operation strings to SQL equivalents.
Example:
>>> _pl_operation_to_sql("Eq")
'='
"""
try:
return {
"Lt": "<",
"LtEq": "<=",
"Gt": ">",
"GtEq": ">=",
"Eq": "=",
"Modulus": "%",
"And": "AND",
"Or": "OR",
}[op]
except KeyError:
raise NotImplementedError(op) # noqa: B904
def _escape_sql_identifier(identifier: str) -> str:
"""Escape SQL identifiers by doubling any double quotes and wrapping in double quotes.
Example:
>>> _escape_sql_identifier('column"name')
'"column""name"'
"""
escaped = identifier.replace('"', '""')
return f'"{escaped}"'
def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
"""Recursively convert a Polars expression tree (as JSON) to a SQL string.
Parameters:
tree (dict): JSON-deserialized expression tree from Polars
Returns:
str: SQL expression string
Example:
Input tree:
{
"BinaryExpr": {
"left": { "Column": "foo" },
"op": "Gt",
"right": { "Literal": { "Int": 5 } }
}
}
Output: "(foo > 5)"
"""
[node_type] = tree.keys()
if node_type == "BinaryExpr":
# Binary expressions: left OP right
bin_expr_tree = tree[node_type]
assert isinstance(bin_expr_tree, dict), f"A {node_type} should be a dict but got {type(bin_expr_tree)}"
lhs, op, rhs = bin_expr_tree["left"], bin_expr_tree["op"], bin_expr_tree["right"]
assert isinstance(lhs, dict), f"LHS of a {node_type} should be a dict but got {type(lhs)}"
assert isinstance(op, str), f"The op of a {node_type} should be a str but got {type(op)}"
assert isinstance(rhs, dict), f"RHS of a {node_type} should be a dict but got {type(rhs)}"
return f"({_pl_tree_to_sql(lhs)} {_pl_operation_to_sql(op)} {_pl_tree_to_sql(rhs)})"
if node_type == "Column":
# A reference to a column name
# Wrap in quotes to handle special characters
col_name = tree[node_type]
assert isinstance(col_name, str), f"The col name of a {node_type} should be a str but got {type(col_name)}"
return _escape_sql_identifier(col_name)
if node_type in ("Literal", "Dyn"):
# Recursively process dynamic or literal values
val_tree = tree[node_type]
assert isinstance(val_tree, dict), f"A {node_type} should be a dict but got {type(val_tree)}"
return _pl_tree_to_sql(val_tree)
if node_type == "Int":
# Direct integer literals
int_literal = tree[node_type]
assert isinstance(int_literal, (int, str)), (
f"The value of an Int should be an int or str but got {type(int_literal)}"
)
return str(int_literal)
if node_type == "Function":
# Handle boolean functions like IsNull, IsNotNull
func_tree = tree[node_type]
assert isinstance(func_tree, dict), f"A {node_type} should be a dict but got {type(func_tree)}"
inputs = func_tree["input"]
assert isinstance(inputs, list), f"A {node_type} should have a list of dicts as input but got {type(inputs)}"
input_tree = inputs[0]
assert isinstance(input_tree, dict), (
f"A {node_type} should have a list of dicts as input but got {type(input_tree)}"
)
func_dict = func_tree["function"]
assert isinstance(func_dict, dict), (
f"A {node_type} should have a function dict as input but got {type(func_dict)}"
)
if "Boolean" in func_dict:
func = func_dict["Boolean"]
arg_sql = _pl_tree_to_sql(inputs[0])
if func == "IsNull":
return f"({arg_sql} IS NULL)"
if func == "IsNotNull":
return f"({arg_sql} IS NOT NULL)"
msg = f"Boolean function not supported: {func}"
raise NotImplementedError(msg)
msg = f"Unsupported function type: {func_dict}"
raise NotImplementedError(msg)
if node_type == "Scalar":
# Detect format: old style (dtype/value) or new style (direct type key)
scalar_tree = tree[node_type]
assert isinstance(scalar_tree, dict), f"A {node_type} should be a dict but got {type(scalar_tree)}"
if "dtype" in scalar_tree and "value" in scalar_tree:
dtype = str(scalar_tree["dtype"])
value = scalar_tree["value"]
else:
# New style: dtype is the single key in the dict
dtype = next(iter(scalar_tree.keys()))
value = scalar_tree
assert isinstance(dtype, str), f"A {node_type} should have a str dtype but got {type(dtype)}"
assert isinstance(value, dict), f"A {node_type} should have a dict value but got {type(value)}"
# Decimal support
if dtype.startswith("{'Decimal'") or dtype == "Decimal":
decimal_value = value["Decimal"]
assert isinstance(decimal_value, list), (
f"A {dtype} should be a two or three member list but got {type(decimal_value)}"
)
assert 2 <= len(decimal_value) <= 3, (
f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list"
)
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1]))
# Datetime with microseconds since epoch
if dtype.startswith("{'Datetime'") or dtype == "Datetime":
micros = value["Datetime"]
assert isinstance(micros, list), f"A {dtype} should be a one member list but got {type(micros)}"
dt_timestamp = datetime.datetime.fromtimestamp(micros[0] / 1_000_000, tz=datetime.timezone.utc)
return f"'{dt_timestamp!s}'::TIMESTAMP"
# Match simple numeric/boolean types
if dtype in (
"Int8",
"Int16",
"Int32",
"Int64",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"Float32",
"Float64",
"Boolean",
):
return str(value[dtype])
# Time type
if dtype == "Time":
nanoseconds = value["Time"]
assert isinstance(nanoseconds, int), f"A {dtype} should be an int but got {type(nanoseconds)}"
seconds = nanoseconds // 1_000_000_000
microseconds = (nanoseconds % 1_000_000_000) // 1_000
dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time()
return f"'{dt_time}'::TIME"
# Date type
if dtype == "Date":
days_since_epoch = value["Date"]
assert isinstance(days_since_epoch, (float, int)), (
f"A {dtype} should be a number but got {type(days_since_epoch)}"
)
date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch)
return f"'{date}'::DATE"
# Binary type
if dtype == "Binary":
bin_value = value["Binary"]
assert isinstance(bin_value, list), f"A {dtype} should be a list but got {type(bin_value)}"
binary_data = bytes(bin_value)
escaped = "".join(f"\\x{b:02x}" for b in binary_data)
return f"'{escaped}'::BLOB"
# String type
if dtype == "String" or dtype == "StringOwned":
# Some new formats may store directly under StringOwned
string_val = value.get("StringOwned", value.get("String", None))
# the string must be a string constant
return str(duckdb.ConstantExpression(string_val))
msg = f"Unsupported scalar type {dtype!s}, with value {value}"
raise NotImplementedError(msg)
msg = f"Node type: {node_type} is not implemented. {tree[node_type]}"
raise NotImplementedError(msg)
def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame:
    """Register a DuckDB relation as a Polars IO source and return the resulting lazy frame.

    The registered generator applies the requested projection and row limit to the relation, tries to push
    the predicate down to DuckDB, and otherwise filters every yielded batch in Polars.
    """

    def source_generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        rel = relation
        if with_columns is not None:
            projection = ",".join(_escape_sql_identifier(name) for name in with_columns)
            rel = rel.project(projection)
        if n_rows is not None:
            rel = rel.limit(n_rows)

        pushed_down = None
        if predicate is not None:
            # We have a predicate, if possible, we push it down to DuckDB
            with contextlib.suppress(AssertionError, KeyError):
                pushed_down = _predicate_to_expression(predicate)
        # Try to pushdown filter, if one exists
        if pushed_down is not None:
            rel = rel.filter(pushed_down)

        reader = rel.fetch_arrow_reader() if batch_size is None else rel.fetch_arrow_reader(batch_size)

        # A predicate that could not be pushed down has to be applied to each batch on the Polars side
        filter_in_polars = predicate is not None and pushed_down is None
        for record_batch in iter(reader.read_next_batch, None):
            if filter_in_polars:
                yield pl.from_arrow(record_batch).filter(predicate)  # type: ignore[arg-type,misc,unused-ignore]
            else:
                yield pl.from_arrow(record_batch)  # type: ignore[misc,unused-ignore]

    return register_io_source(source_generator, schema=schema)

View File

View File

@@ -0,0 +1,358 @@
import argparse # noqa: D100
import json
import re
import webbrowser
from functools import reduce
from pathlib import Path
# Stylesheet text for the generated report: the timing table (.styled-table), the tree nodes
# (.node-body, .tf-nc) and the hover tooltip (.custom-tooltip, .tooltip-text).
qgraph_css = """
.styled-table {
border-collapse: collapse;
margin: 25px 0;
font-size: 0.9em;
font-family: sans-serif;
min-width: 400px;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.15);
}
.styled-table thead tr {
background-color: #009879;
color: #ffffff;
text-align: left;
}
.styled-table th,
.styled-table td {
padding: 12px 15px;
}
.styled-table tbody tr {
border-bottom: 1px solid #dddddd;
}
.styled-table tbody tr:nth-of-type(even) {
background-color: #f3f3f3;
}
.styled-table tbody tr:last-of-type {
border-bottom: 2px solid #009879;
}
.node-body {
font-size:15px;
}
.tf-nc {
position: relative;
width: 180px;
text-align: center;
background-color: #fff100;
}
.custom-tooltip {
position: relative;
display: inline-block;
}
.tooltip-text {
visibility: hidden;
background-color: #333;
color: #fff;
text-align: center;
padding: 0px;
border-radius: 1px;
/* Positioning */
position: absolute;
z-index: 1;
bottom: 100%;
left: 50%;
transform: translateX(-50%);
margin-bottom: 8px;
/* Tooltip Arrow */
width: 400px;
}
.custom-tooltip:hover .tooltip-text {
visibility: visible;
}
"""
class NodeTiming:
    """Time spent in one phase, plus its share of a total once that share has been calculated."""

    def __init__(self, phase: str, time: float) -> None:
        """Store the phase name and its time; the percentage starts at 0 until calculated."""
        self.phase = phase
        self.time = time
        # percentage is determined later.
        self.percentage = 0

    def calculate_percentage(self, total_time: float) -> None:
        """Set `percentage` to this timing's fraction of `total_time` (a ratio, not multiplied by 100).

        A `total_time` of 0 yields a percentage of 0 instead of raising ZeroDivisionError.
        """
        self.percentage = self.time / total_time if total_time else 0

    def combine_timing(self, r: "NodeTiming") -> "NodeTiming":
        """Return a new timing that keeps this phase and sums both times; neither operand is modified."""
        # TODO: can only add timings for same-phase nodes  # noqa: TD002, TD003
        total_time = self.time + r.time
        return NodeTiming(self.phase, total_time)
class AllTimings:
    """Registry of `NodeTiming` entries grouped by phase name."""

    def __init__(self) -> None:
        """Start with no recorded timings."""
        # Maps a phase name to the list of timings recorded for it, in insertion order
        self.phase_to_timings = {}

    def add_node_timing(self, node_timing: NodeTiming) -> None:
        """Append `node_timing` to the list kept for its phase, creating the list if needed."""
        self.phase_to_timings.setdefault(node_timing.phase, []).append(node_timing)

    def get_phase_timings(self, phase: str) -> list[NodeTiming]:
        """Return every timing recorded for `phase`; raises KeyError for an unknown phase."""
        return self.phase_to_timings[phase]

    def get_summary_phase_timings(self, phase: str) -> NodeTiming:
        """Return the timings of `phase` combined into one; raises KeyError for an unknown phase."""
        return reduce(NodeTiming.combine_timing, self.phase_to_timings[phase])

    def get_phases(self) -> list[str]:
        """Return the phase names ordered from the largest to the smallest combined time."""
        phases = sorted(self.phase_to_timings, key=lambda phase: self.get_summary_phase_timings(phase).time)
        # Ascending sort followed by a reversal (rather than reverse=True) keeps the ordering of ties
        # exactly as before
        phases.reverse()
        return phases

    def get_sum_of_all_timings(self) -> float:
        """Return the combined time of all phases (0 when nothing has been recorded)."""
        return sum(self.get_summary_phase_timings(phase).time for phase in self.phase_to_timings)
def open_utf8(fpath: str, flags: str) -> object:
    """Open the file at `fpath` with mode `flags` using UTF-8 encoding and return the file object."""
    target = Path(fpath)
    return target.open(mode=flags, encoding="utf8")
def get_child_timings(top_node: object, query_timings: object) -> None:
    """Record the timing of `top_node` and, recursively, of all its descendants in `query_timings`.

    Each node must provide "operator_type", "operator_timing" and "children"; nodes are recorded parent
    first, then children in order. Nothing is returned: results are accumulated in `query_timings`.
    """
    node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"]))
    query_timings.add_node_timing(node_timing)
    for child in top_node["children"]:
        get_child_timings(child, query_timings)
def get_pink_shade_hex(fraction: float) -> str:
    """Return a hex colour between very light pink (fraction 0) and dark pink (fraction 1).

    `fraction` is clamped to the range [0, 1] before interpolating each RGB channel linearly.
    """
    clamped = max(0, min(1, fraction))
    light_pink = (255, 250, 250)  # Very light pink (almost white)
    dark_pink = (255, 20, 147)  # Dark pink
    # Interpolate channel by channel, truncating to an integer
    channels = [int(light + (dark - light) * clamped) for light, dark in zip(light_pink, dark_pink)]
    return "#" + "".join(f"{channel:02x}" for channel in channels)
def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str:
    """Render the HTML for one operator node of the query tree.

    Parameters:
        name: Operator type; "INVALID" is displayed as "BRIDGE" and underscores become spaces.
        result: Operator timing in seconds (anything `float()` accepts).
        cpu_time: Total CPU time used to scale the background colour; 0 selects the lightest shade.
        card: Operator cardinality, shown only when `width` is positive.
        est: Estimated cardinality, shown only when `width` is positive.
        width: Row width in bytes.
        extra_info: Pre-formatted HTML shown in the hover tooltip; it is inserted as-is.

    Returns:
        The node markup as a string.
    """
    # A profile can report a CPU time of 0; treat that as "no share" instead of dividing by zero
    time_fraction = float(result) / cpu_time if cpu_time else 0.0
    node_style = f"background-color: {get_pink_shade_hex(time_fraction)};"
    new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ")
    formatted_num = f"{float(result):.4f}"

    parts = [
        f'<span class="tf-nc custom-tooltip" style="{node_style}">',
        '<div class="node-body">',
        f"<p><b>{new_name}</b> </p><p>time: {formatted_num} seconds</p>",
        f'<span class="tooltip-text"> {extra_info} </span>',
    ]
    if width > 0:
        parts.append(f"<p>cardinality: {card}</p>")
        parts.append(f"<p>estimate: {est}</p>")
        parts.append(f"<p>width: {width} bytes</p>")
    # TODO: Expand on timing. Usually available from a detailed profiling  # noqa: TD002, TD003
    parts.append("</div>")
    parts.append("</span>")
    return "".join(parts)
def generate_tree_recursive(json_graph: object, cpu_time: float) -> str:
    """Render an operator node and all of its descendants as nested `<li>`/`<ul>` HTML.

    "Estimated Cardinality" is pulled out of the node's "extra_info" as the estimate; every other entry
    becomes a line of the tooltip text. The row width is the result set size divided by the cardinality
    (treating a cardinality below 1 as 1), truncated to an integer.
    """
    node_extra = json_graph["extra_info"]
    estimate = 0
    tooltip_lines = []
    for key in node_extra:
        if key == "Estimated Cardinality":
            estimate = int(node_extra[key])
        else:
            tooltip_lines.append(f"{key}: {node_extra[key]} <br>")
    extra_info = "".join(tooltip_lines)

    cardinality = json_graph["operator_cardinality"]
    width = int(json_graph["result_set_size"] / max(1, cardinality))

    # get rid of some typically long names
    extra_info = re.sub(r"__internal_\s*", "__", extra_info)
    extra_info = re.sub(r"compress_integral\s*", "compress", extra_info)
    # normalise the spacing after commas
    extra_info = re.sub(r",\s*", ", ", extra_info)

    node_body = get_node_body(
        json_graph["operator_type"],
        json_graph["operator_timing"],
        cpu_time,
        cardinality,
        estimate,
        width,
        extra_info,
    )

    children = json_graph["children"]
    children_html = ""
    if len(children) >= 1:
        rendered_children = "".join(generate_tree_recursive(child, cpu_time) for child in children)
        children_html = "<ul>" + rendered_children + "</ul>"
    return "<li>" + node_body + children_html + "</li>"
# For generating the table in the top left.
def generate_timing_html(graph_json: object, query_timings: object) -> object:
    """Build the HTML table listing the time and percentage of every phase.

    `graph_json` is the profile as JSON text. `query_timings` is filled with the operator timings found in
    the profile and additionally receives the "TOTAL TIME" and "Execution Time" summary entries, which
    are listed first. The total comes from "operator_timing", falling back to "latency".
    """
    profile = json.loads(graph_json)
    gather_timing_information(profile, query_timings)
    total_time = float(profile.get("operator_timing") or profile.get("latency"))

    # Read the operator phases and their combined time before the two summary entries are registered
    execution_time = query_timings.get_sum_of_all_timings()
    operator_phases = query_timings.get_phases()
    query_timings.add_node_timing(NodeTiming("TOTAL TIME", total_time))
    query_timings.add_node_timing(NodeTiming("Execution Time", execution_time))

    summary_phases = ("TOTAL TIME", "Execution Time")
    rows = []
    for phase in [*summary_phases, *operator_phases]:
        summary = query_timings.get_summary_phase_timings(phase)
        summary.calculate_percentage(total_time)
        # The two summary rows are rendered in bold
        label = f"<b>{phase}</b>" if phase in summary_phases else phase
        rows.append(f"""
<tr>
<td>{label}</td>
<td>{summary.time}</td>
<td>{str(summary.percentage * 100)[:6]}%</td>
</tr>
""")

    table_head = """
<table class=\"styled-table\">
<thead>
<tr>
<th>Phase</th>
<th>Time</th>
<th>Percentage</th>
</tr>
</thead>"""
    return table_head + "<tbody>" + "".join(rows) + "</tbody></table>"
def generate_tree_html(graph_json: object) -> str:
    """Render the operator tree of a JSON profile as the treeflex `<div>` markup.

    `graph_json` is the profile as JSON text; its "cpu_time" scales the node colours and rendering starts
    at the first entry of its top-level "children".
    """
    profile = json.loads(graph_json)
    cpu_time = float(profile["cpu_time"])
    # first level of json is general overview
    # TODO: make sure json output first level always has only 1 level  # noqa: TD002, TD003
    root_operator = profile["children"][0]
    tree_body = generate_tree_recursive(root_operator, cpu_time)
    return f'<div class="tf-tree tf-gap-sm"> \n <ul>{tree_body}</ul> </div>'
def generate_ipython(json_input: str) -> str:  # noqa: D103
    """Wrap the generated report fragments in an IPython `HTML` display object.

    NOTE(review): `generate_html` is not defined anywhere in the visible part of this module (hence the
    F821 suppression below), so calling this function looks like it would raise NameError - confirm.
    The "css" key read below is also not among the keys returned by `generate_style_html`.
    The declared return type is `str`, but the value returned is the `HTML` object.
    """
    # Imported lazily so IPython is only needed when this function is actually called
    from IPython.core.display import HTML
    html_output = generate_html(json_input, False)  # noqa: F821
    # Fill the ${...} placeholders of the template with the generated fragments
    return HTML(
        ('\n ${CSS}\n ${LIBRARIES}\n <div class="chart" id="query-profile"></div>\n ${CHART_SCRIPT}\n ')
        .replace("${CSS}", html_output["css"])
        .replace("${CHART_SCRIPT}", html_output["chart_script"])
        .replace("${LIBRARIES}", html_output["libraries"])
    )
def generate_style_html(graph_json: str, include_meta_info: bool) -> dict[str, str]:  # noqa: FBT001
    """Return the style-related HTML fragments of the report.

    Both parameters are currently unused; they are kept so existing callers keep working.

    Returns:
        A dict with the keys "treeflex_css" (the stylesheet `<link>` tag), "duckdb_css" (`qgraph_css`
        wrapped in a `<style>` element), and "libraries" and "chart_script" (both empty strings).
    """
    treeflex_css = '<link rel="stylesheet" href="https://unpkg.com/treeflex/dist/css/treeflex.css">\n'
    duckdb_css = "<style>\n" + qgraph_css + "\n" + "</style>\n"
    return {"treeflex_css": treeflex_css, "duckdb_css": duckdb_css, "libraries": "", "chart_script": ""}
def gather_timing_information(json: dict, query_timings: object) -> None:
    """Record the timings of every operator below the profile's first top-level child.

    `json` is the already-parsed profile (a mapping with a "children" list), not JSON text. Nothing is
    returned: the timings are accumulated in `query_timings`.
    """
    # The first level of the profile is the general overview; the operator tree starts at its first child
    get_child_timings(json["children"][0], query_timings)
def translate_json_to_html(input_file: str, output_file: str) -> None:
    """Read a JSON profile from `input_file` and write the HTML report to `output_file`.

    The output file is opened only after the timing table and the tree have been generated.
    """
    query_timings = AllTimings()
    with open_utf8(input_file, "r") as profile_file:
        text = profile_file.read()

    html_output = generate_style_html(text, True)
    timing_table = generate_timing_html(text, query_timings)
    tree_output = generate_tree_html(text)

    # finally create and write the html
    with open_utf8(output_file, "w+") as report_file:
        html = """<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<title>Query Profile Graph for Query</title>
${TREEFLEX_CSS}
<style>
${DUCKDB_CSS}
</style>
</head>
<body>
<div id="meta-info"></div>
<div class="chart" id="query-profile">
${TIMING_TABLE}
</div>
${TREE}
</body>
</html>
"""
        # Placeholders are substituted one after the other, in this order
        substitutions = (
            ("${TREEFLEX_CSS}", html_output["treeflex_css"]),
            ("${DUCKDB_CSS}", html_output["duckdb_css"]),
            ("${TIMING_TABLE}", timing_table),
            ("${TREE}", tree_output),
        )
        for placeholder, replacement in substitutions:
            html = html.replace(placeholder, replacement)
        report_file.write(html)
def main() -> None:
    """Command-line entry point: convert a JSON profile into an HTML report and optionally open it.

    Without `--out`, the output name is the input name with ".json" replaced by ".html". The report is
    opened in the web browser unless `--no-open` is given. Exits with status 1 when the input name does
    not contain ".json" (and no `--out` is given) or when `--out` does not contain ".html".
    """
    import sys

    parser = argparse.ArgumentParser(
        prog="Query Graph Generator",
        description="""Given a json profile output, generate a html file showing the query graph and
timings of operators""",
    )
    parser.add_argument("profile_input", help="profile input in json")
    parser.add_argument("--out", required=False, default=False)
    # With `store_true` and default=True the flag could never be switched off. BooleanOptionalAction keeps
    # `--open` (and the default) working and adds `--no-open`.
    parser.add_argument("--open", required=False, action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()

    profile_path = args.profile_input
    if not args.out:
        if ".json" not in profile_path:
            print("please provide profile output in json")
            sys.exit(1)
        output = profile_path.replace(".json", ".html")
    else:
        if ".html" not in args.out:
            print("please provide valid .html file for output name")
            sys.exit(1)
        output = args.out

    translate_json_to_html(profile_path, output)

    if args.open:
        webbrowser.open(f"file://{Path(output).resolve()}", new=2)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,63 @@
"""DuckDB's SQL types."""
from _duckdb._sqltypes import (
BIGINT,
BIT,
BLOB,
BOOLEAN,
DATE,
DOUBLE,
FLOAT,
HUGEINT,
INTEGER,
INTERVAL,
SMALLINT,
SQLNULL,
TIME,
TIME_TZ,
TIMESTAMP,
TIMESTAMP_MS,
TIMESTAMP_NS,
TIMESTAMP_S,
TIMESTAMP_TZ,
TINYINT,
UBIGINT,
UHUGEINT,
UINTEGER,
USMALLINT,
UTINYINT,
UUID,
VARCHAR,
DuckDBPyType,
)
# Public names of this module: exactly the names imported from `_duckdb._sqltypes` above.
__all__ = [
    "BIGINT",
    "BIT",
    "BLOB",
    "BOOLEAN",
    "DATE",
    "DOUBLE",
    "FLOAT",
    "HUGEINT",
    "INTEGER",
    "INTERVAL",
    "SMALLINT",
    "SQLNULL",
    "TIME",
    "TIMESTAMP",
    "TIMESTAMP_MS",
    "TIMESTAMP_NS",
    "TIMESTAMP_S",
    "TIMESTAMP_TZ",
    "TIME_TZ",
    "TINYINT",
    "UBIGINT",
    "UHUGEINT",
    "UINTEGER",
    "USMALLINT",
    "UTINYINT",
    "UUID",
    "VARCHAR",
    "DuckDBPyType",
]

View File

@@ -0,0 +1,71 @@
"""DuckDB's SQL types. DEPRECATED. Please use `duckdb.sqltypes` instead."""
import warnings
from duckdb.sqltypes import (
BIGINT,
BIT,
BLOB,
BOOLEAN,
DATE,
DOUBLE,
FLOAT,
HUGEINT,
INTEGER,
INTERVAL,
SMALLINT,
SQLNULL,
TIME,
TIME_TZ,
TIMESTAMP,
TIMESTAMP_MS,
TIMESTAMP_NS,
TIMESTAMP_S,
TIMESTAMP_TZ,
TINYINT,
UBIGINT,
UHUGEINT,
UINTEGER,
USMALLINT,
UTINYINT,
UUID,
VARCHAR,
DuckDBPyType,
)
# Re-export the names imported from `duckdb.sqltypes` above so the deprecated module keeps the same
# public surface.
__all__ = [
    "BIGINT",
    "BIT",
    "BLOB",
    "BOOLEAN",
    "DATE",
    "DOUBLE",
    "FLOAT",
    "HUGEINT",
    "INTEGER",
    "INTERVAL",
    "SMALLINT",
    "SQLNULL",
    "TIME",
    "TIMESTAMP",
    "TIMESTAMP_MS",
    "TIMESTAMP_NS",
    "TIMESTAMP_S",
    "TIMESTAMP_TZ",
    "TIME_TZ",
    "TINYINT",
    "UBIGINT",
    "UHUGEINT",
    "UINTEGER",
    "USMALLINT",
    "UTINYINT",
    "UUID",
    "VARCHAR",
    "DuckDBPyType",
]
# Runs at import time: merely importing this module emits the deprecation warning.
# stacklevel=2 attributes the warning to the frame above this module rather than to this file.
warnings.warn(
    "`duckdb.typing` is deprecated and will be removed in a future version. Please use `duckdb.sqltypes` instead.",
    DeprecationWarning,
    stacklevel=2,
)

View File

@@ -0,0 +1,24 @@
# ruff: noqa: D100
import typing
def vectorized(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]:
    """Decorate a function with annotated function parameters.

    This allows DuckDB to infer that the function should be provided with pyarrow arrays and should expect
    pyarrow array(s) as output.

    The returned function is a copy of `func` (same code, globals, name, defaults and closure) whose
    annotations mark every parameter as `pyarrow.lib.ChunkedArray`; `func` itself is left untouched.
    Any return annotation of `func` is not carried over.
    """
    import types
    from inspect import signature

    import pyarrow as pa

    new_func = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__)
    # FunctionType() has no argument for keyword-only defaults, so without this a parameter declared as
    # `*, flag=True` would become required on the copy. Copy the dict so the two functions do not share it.
    kwdefaults = func.__kwdefaults__
    new_func.__kwdefaults__ = dict(kwdefaults) if kwdefaults else None

    # Construct the annotations: every parameter is declared as a pyarrow ChunkedArray
    new_func.__annotations__ = dict.fromkeys(signature(func).parameters, pa.lib.ChunkedArray)
    return new_func

View File

@@ -0,0 +1 @@
# noqa: D104

View File

@@ -0,0 +1,270 @@
# ruff: noqa: D101, D104, D105, D107, ANN401
from typing import Any
from duckdb.sqltypes import (
BIGINT,
BIT,
BLOB,
BOOLEAN,
DATE,
DOUBLE,
FLOAT,
HUGEINT,
INTEGER,
INTERVAL,
SMALLINT,
SQLNULL,
TIME,
TIME_TZ,
TIMESTAMP,
TIMESTAMP_MS,
TIMESTAMP_NS,
TIMESTAMP_S,
TIMESTAMP_TZ,
TINYINT,
UBIGINT,
UHUGEINT,
UINTEGER,
USMALLINT,
UTINYINT,
UUID,
VARCHAR,
DuckDBPyType,
)
class Value:
    """Pairs a Python object with an explicit DuckDB type."""

    def __init__(self, object: Any, type: DuckDBPyType) -> None:
        """Store `object` and `type` unchanged on the instance; no validation or conversion happens here."""
        self.object = object
        self.type = type

    def __repr__(self) -> str:
        """Return `str()` of the wrapped object; the type is not part of the representation."""
        return str(self.object)
# Miscellaneous
class NullValue(Value):
    """A `Value` holding `None`, typed as SQLNULL."""

    def __init__(self) -> None:
        super().__init__(None, SQLNULL)


class BooleanValue(Value):
    """A `Value` typed as BOOLEAN."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, BOOLEAN)
# Unsigned numerics
class UnsignedBinaryValue(Value):
    """A `Value` typed as UTINYINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, UTINYINT)


class UnsignedShortValue(Value):
    """A `Value` typed as USMALLINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, USMALLINT)


class UnsignedIntegerValue(Value):
    """A `Value` typed as UINTEGER."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, UINTEGER)


class UnsignedLongValue(Value):
    """A `Value` typed as UBIGINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, UBIGINT)
# Signed numerics
class BinaryValue(Value):
    """A `Value` typed as TINYINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TINYINT)


class ShortValue(Value):
    """A `Value` typed as SMALLINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, SMALLINT)


class IntegerValue(Value):
    """A `Value` typed as INTEGER."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, INTEGER)


class LongValue(Value):
    """A `Value` typed as BIGINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, BIGINT)


class HugeIntegerValue(Value):
    """A `Value` typed as HUGEINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, HUGEINT)


# NOTE(review): this class uses the unsigned UHUGEINT type although it sits in the "Signed numerics"
# section - presumably only a placement quirk.
class UnsignedHugeIntegerValue(Value):
    """A `Value` typed as UHUGEINT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, UHUGEINT)
# Fractional
class FloatValue(Value):
    """A `Value` typed as FLOAT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, FLOAT)


class DoubleValue(Value):
    """A `Value` typed as DOUBLE."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, DOUBLE)


class DecimalValue(Value):
    """A `Value` whose type is built with `duckdb.decimal_type(width, scale)`."""

    def __init__(self, object: Any, width: int, scale: int) -> None:
        # `duckdb` is imported inside the method rather than at module level
        import duckdb
        decimal_type = duckdb.decimal_type(width, scale)
        super().__init__(object, decimal_type)
# String
class StringValue(Value):
    """A `Value` typed as VARCHAR."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, VARCHAR)


class UUIDValue(Value):
    """A `Value` typed as UUID."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, UUID)


class BitValue(Value):
    """A `Value` typed as BIT."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, BIT)


class BlobValue(Value):
    """A `Value` typed as BLOB."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, BLOB)
# Temporal
class DateValue(Value):
    """A `Value` typed as DATE."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, DATE)


class IntervalValue(Value):
    """A `Value` typed as INTERVAL."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, INTERVAL)


class TimestampValue(Value):
    """A `Value` typed as TIMESTAMP."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIMESTAMP)


class TimestampSecondValue(Value):
    """A `Value` typed as TIMESTAMP_S."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIMESTAMP_S)


# NOTE(review): "Milisecond" is spelled with a single "l"; the class name is public API, so it is
# left untouched.
class TimestampMilisecondValue(Value):
    """A `Value` typed as TIMESTAMP_MS."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIMESTAMP_MS)


class TimestampNanosecondValue(Value):
    """A `Value` typed as TIMESTAMP_NS."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIMESTAMP_NS)


class TimestampTimeZoneValue(Value):
    """A `Value` typed as TIMESTAMP_TZ."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIMESTAMP_TZ)


class TimeValue(Value):
    """A `Value` typed as TIME."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIME)


class TimeTimeZoneValue(Value):
    """A `Value` typed as TIME_TZ."""

    def __init__(self, object: Any) -> None:
        super().__init__(object, TIME_TZ)
class ListValue(Value):
    """A `Value` whose type is built with `duckdb.list_type(child_type)`."""

    def __init__(self, object: Any, child_type: DuckDBPyType) -> None:
        # `duckdb` is imported inside the method rather than at module level
        import duckdb
        list_type = duckdb.list_type(child_type)
        super().__init__(object, list_type)


class StructValue(Value):
    """A `Value` whose type is built with `duckdb.struct_type(children)`."""

    def __init__(self, object: Any, children: dict[str, DuckDBPyType]) -> None:
        # `duckdb` is imported inside the method rather than at module level
        import duckdb
        struct_type = duckdb.struct_type(children)
        super().__init__(object, struct_type)


class MapValue(Value):
    """A `Value` whose type is built with `duckdb.map_type(key_type, value_type)`."""

    def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType) -> None:
        # `duckdb` is imported inside the method rather than at module level
        import duckdb
        map_type = duckdb.map_type(key_type, value_type)
        super().__init__(object, map_type)


# NOTE(review): unlike its siblings this class is named `...Type` rather than `...Value`, although
# it is a `Value` subclass; the name is public API, so it is left untouched.
class UnionType(Value):
    """A `Value` whose type is built with `duckdb.union_type(members)`."""

    def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None:
        # `duckdb` is imported inside the method rather than at module level
        import duckdb
        union_type = duckdb.union_type(members)
        super().__init__(object, union_type)
# TODO: add EnumValue once `duckdb.enum_type` is added # noqa: TD002, TD003
# Public names of this module. The nested-type wrappers (ListValue, MapValue, StructValue, UnionType)
# are defined in this module like every other value class, so they are exported as well.
__all__ = [
    "BinaryValue",
    "BitValue",
    "BlobValue",
    "BooleanValue",
    "DateValue",
    "DecimalValue",
    "DoubleValue",
    "FloatValue",
    "HugeIntegerValue",
    "IntegerValue",
    "IntervalValue",
    "ListValue",
    "LongValue",
    "MapValue",
    "NullValue",
    "ShortValue",
    "StringValue",
    "StructValue",
    "TimeTimeZoneValue",
    "TimeValue",
    "TimestampMilisecondValue",
    "TimestampNanosecondValue",
    "TimestampSecondValue",
    "TimestampTimeZoneValue",
    "TimestampValue",
    "UUIDValue",
    "UnionType",
    "UnsignedBinaryValue",
    "UnsignedHugeIntegerValue",
    "UnsignedIntegerValue",
    "UnsignedLongValue",
    "UnsignedShortValue",
    "Value",
]