from __future__ import annotations
import contextlib
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence
import duckdb
from duckdb import FunctionExpression, StarExpression
from narwhals._duckdb.utils import (
DeferredTimeZone,
col,
evaluate_exprs,
generate_partition_by_sql,
lit,
native_to_narwhals_dtype,
)
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
parse_version,
validate_backend_version,
)
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantLazyFrame
if TYPE_CHECKING:
from types import ModuleType
import pandas as pd
import pyarrow as pa
from duckdb import Expression
from duckdb.typing import DuckDBPyType
from typing_extensions import Self, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.group_by import DuckDBGroupBy
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.series import DuckDBInterchangeSeries
from narwhals._utils import _FullContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.stable.v1 import DataFrame as DataFrameV1
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
from duckdb import SQLExpression
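# NOTE: `SQLExpression` is only used in `unique` below, which raises for older
# DuckDB versions before touching it, so the optional import is safe.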
class DuckDBLazyFrame(
CompliantLazyFrame[
"DuckDBExpr",
"duckdb.DuckDBPyRelation",
"LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
]
):
_implementation = Implementation.DUCKDB
def __init__(
self,
df: duckdb.DuckDBPyRelation,
*,
backend_version: tuple[int, ...],
version: Version,
) -> None:
self._native_frame: duckdb.DuckDBPyRelation = df
self._version = version
self._backend_version = backend_version
self._cached_native_schema: dict[str, DuckDBPyType] | None = None
self._cached_columns: list[str] | None = None
validate_backend_version(self._implementation, self._backend_version)
@staticmethod
def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
return isinstance(obj, duckdb.DuckDBPyRelation)
@classmethod
def from_native(
cls, data: duckdb.DuckDBPyRelation, /, *, context: _FullContext
) -> Self:
return cls(
data, backend_version=context._backend_version, version=context._version
)
def to_narwhals(
self, *args: Any, **kwds: Any
) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
if self._version is Version.MAIN:
return self._version.lazyframe(self, level="lazy")
from narwhals.stable.v1 import DataFrame as DataFrameV1
return DataFrameV1(self, level="interchange") # type: ignore[no-any-return]
def __narwhals_dataframe__(self) -> Self: # pragma: no cover
# Keep around for backcompat.
if self._version is not Version.V1:
msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
raise AttributeError(msg)
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
return get_duckdb() # type: ignore[no-any-return]
def __narwhals_namespace__(self) -> DuckDBNamespace:
from narwhals._duckdb.namespace import DuckDBNamespace
return DuckDBNamespace(
backend_version=self._backend_version, version=self._version
)
def get_column(self, name: str) -> DuckDBInterchangeSeries:
from narwhals._duckdb.series import DuckDBInterchangeSeries
return DuckDBInterchangeSeries(self.native.select(name), version=self._version)
def _iter_columns(self) -> Iterator[Expression]:
for name in self.columns:
yield col(name)
def collect(
self, backend: ModuleType | Implementation | str | None, **kwargs: Any
) -> CompliantDataFrameAny:
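        # Materialise the relation eagerly into the requested backend: PyArrow (the
        # default when `backend is None`) via `.arrow()`, pandas via `.df()`, and
        # Polars via `.pl()`.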
if backend is None or backend is Implementation.PYARROW:
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native.arrow(),
backend_version=parse_version(pa),
version=self._version,
validate_column_names=True,
)
if backend is Implementation.PANDAS:
import pandas as pd # ignore-banned-import
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.df(),
implementation=Implementation.PANDAS,
backend_version=parse_version(pd),
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
self.native.pl(), backend_version=parse_version(pl), version=self._version
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def head(self, n: int) -> Self:
return self._with_native(self.native.limit(n))
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: DuckDBExpr) -> Self:
selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type]
def select(self, *exprs: DuckDBExpr) -> Self:
selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
return self._with_native(self.native.select(*selection))
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
selection = (name for name in self.columns if name not in columns_to_drop)
return self._with_native(self.native.select(*selection))
def lazy(self, *, backend: Implementation | None = None) -> Self:
        # The `backend` argument has no effect, but we keep it here for backwards
        # compatibility because in `narwhals.stable.v1` `.from_native()` returns
        # a DataFrame for DuckDB.
if backend is not None: # pragma: no cover
msg = "`backend` argument is not supported for DuckDB"
raise ValueError(msg)
return self
def with_columns(self, *exprs: DuckDBExpr) -> Self:
new_columns_map = dict(evaluate_exprs(self, *exprs))
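        # Keep the original column order: expressions targeting existing columns
        # replace them in place, and any remaining (new) columns are appended.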
result = [
new_columns_map.pop(name).alias(name)
if name in new_columns_map
else col(name)
for name in self.columns
]
result.extend(value.alias(name) for name, value in new_columns_map.items())
return self._with_native(self.native.select(*result))
def filter(self, predicate: DuckDBExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate(self)[0]
return self._with_native(self.native.filter(mask))
@property
def schema(self) -> dict[str, DType]:
if self._cached_native_schema is None:
# Note: prefer `self._cached_native_schema` over `functools.cached_property`
            # due to Python 3.13 failures.
self._cached_native_schema = dict(zip(self.columns, self.native.types))
deferred_time_zone = DeferredTimeZone(self.native)
return {
column_name: native_to_narwhals_dtype(
duckdb_dtype, self._version, deferred_time_zone
)
for column_name, duckdb_dtype in zip(self.native.columns, self.native.types)
}
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_native_schema is not None
else self.native.columns
)
return self._cached_columns
def to_pandas(self) -> pd.DataFrame:
        # Only used when the version is V1; keep around for backcompat.
import pandas as pd # ignore-banned-import()
if parse_version(pd) >= (1, 0, 0):
return self.native.df()
else: # pragma: no cover
msg = f"Conversion to pandas requires 'pandas>=1.0.0', found {pd.__version__}"
raise NotImplementedError(msg)
def to_arrow(self) -> pa.Table:
        # Only used when the version is V1; keep around for backcompat.
return self.native.arrow()
def _with_version(self, version: Version) -> Self:
return self.__class__(
self.native, version=version, backend_version=self._backend_version
)
def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
return self.__class__(
df, backend_version=self._backend_version, version=self._version
)
def group_by(
self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
) -> DuckDBGroupBy:
from narwhals._duckdb.group_by import DuckDBGroupBy
return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)
def rename(self, mapping: Mapping[str, str]) -> Self:
df = self.native
selection = (
col(name).alias(mapping[name]) if name in mapping else col(name)
for name in df.columns
)
return self._with_native(self.native.select(*selection))
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
native_how = "outer" if how == "full" else how
if native_how == "cross":
if self._backend_version < (1, 1, 4):
msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
raise NotImplementedError(msg)
rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
else:
# help mypy
assert left_on is not None # noqa: S101
assert right_on is not None # noqa: S101
it = (
col(f'lhs."{left}"') == col(f'rhs."{right}"')
for left, right in zip(left_on, right_on)
)
condition: Expression = reduce(and_, it)
rel = self.native.set_alias("lhs").join(
other.native.set_alias("rhs"),
# NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
condition=condition, # type: ignore[arg-type, unused-ignore]
how=native_how,
)
if native_how in {"inner", "left", "cross", "outer"}:
select = [col(f'lhs."{x}"') for x in self.columns]
for name in other.columns:
col_in_lhs: bool = name in self.columns
if native_how == "outer" and not col_in_lhs:
select.append(col(f'rhs."{name}"'))
elif (native_how == "outer") or (
col_in_lhs and (right_on is None or name not in right_on)
):
select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
elif right_on is None or name not in right_on:
select.append(col(name))
res = rel.select(*select).set_alias(self.native.alias)
else: # semi, anti
res = rel.select("lhs.*").set_alias(self.native.alias)
return self._with_native(res)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
lhs = self.native
rhs = other.native
conditions: list[Expression] = []
if by_left is not None and by_right is not None:
conditions.extend(
col(f'lhs."{left}"') == col(f'rhs."{right}"')
for left, right in zip(by_left, by_right)
)
else:
by_left = by_right = []
if strategy == "backward":
conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
elif strategy == "forward":
conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
else:
msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
raise NotImplementedError(msg)
condition: Expression = reduce(and_, conditions)
select = ["lhs.*"]
for name in rhs.columns:
if name in lhs.columns and (
right_on is None or name not in {right_on, *by_right}
):
select.append(f'rhs."{name}" as "{name}{suffix}"')
elif right_on is None or name not in {right_on, *by_right}:
select.append(str(col(name)))
# Replace with Python API call once
# https://github.com/duckdb/duckdb/discussions/16947 is addressed.
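        # Illustrative only (hypothetical column names): with `left_on="ts"`,
        # `right_on="ts"`, `by_left=by_right=["id"]`, `strategy="backward"` and
        # `suffix="_right"`, an rhs column "value" also present in lhs yields roughly
        #   SELECT lhs.*, rhs."value" AS "value_right"
        #   FROM lhs ASOF LEFT JOIN rhs
        #   ON lhs."id" = rhs."id" AND lhs."ts" >= rhs."ts"
        # `lhs` and `rhs` in the query are resolved from the local Python variables
        # above by DuckDB's replacement scans.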
query = f"""
SELECT {",".join(select)}
FROM lhs
ASOF LEFT JOIN rhs
ON {condition}
""" # noqa: S608
return self._with_native(duckdb.sql(query))
def collect_schema(self) -> dict[str, DType]:
return self.schema
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset_ := subset if keep == "any" else (subset or self.columns):
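            # Note: with `keep="any"` and no `subset`, `subset_` is falsy and we fall
            # through to the plain `self.native.unique(...)` call at the end.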
if self._backend_version < (1, 3):
                msg = (
                    "At least version 1.3 of DuckDB is required for the `unique` operation\n"
                    "with `subset` specified."
                )
raise NotImplementedError(msg)
# Sanitise input
if error := self._check_columns_exist(subset_):
raise error
idx_name = generate_temporary_column_name(8, self.columns)
count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
partition_by_sql = generate_partition_by_sql(*(subset_))
name = count_name if keep == "none" else idx_name
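            # Deduplicate via window functions: number the rows within each group of
            # `subset_` values and count each group's size. With `keep="none"` we keep
            # only groups of size one (filter on the count column); otherwise we keep
            # the first row of each group (filter on the row number). Both helper
            # columns are dropped again in the final projection.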
idx_expr = SQLExpression(
f"{FunctionExpression('row_number')} over ({partition_by_sql})"
).alias(idx_name)
count_expr = SQLExpression(
f"{FunctionExpression('count', StarExpression())} over ({partition_by_sql})"
).alias(count_name)
return self._with_native(
self.native.select(StarExpression(), idx_expr, count_expr)
.filter(col(name) == lit(1))
.select(StarExpression(exclude=[count_name, idx_name]))
)
return self._with_native(self.native.unique(", ".join(self.columns)))
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
descending = [descending] * len(by)
if nulls_last:
it = (
col(name).nulls_last() if not desc else col(name).desc().nulls_last()
for name, desc in zip(by, descending)
)
else:
it = (
col(name).nulls_first() if not desc else col(name).desc().nulls_first()
for name, desc in zip(by, descending)
)
return self._with_native(self.native.sort(*it))
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset_ = subset if subset is not None else self.columns
keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
return self._with_native(self.native.filter(keep_condition))
def explode(self, columns: Sequence[str]) -> Self:
dtypes = self._version.dtypes
schema = self.collect_schema()
for name in columns:
dtype = schema[name]
if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)
if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with the DuckDB backend "
                "since we cannot guarantee that the exploded columns have matching "
                "element counts."
            )
raise NotImplementedError(msg)
col_to_explode = col(columns[0])
rel = self.native
original_columns = self.columns
        not_null_condition = col_to_explode.isnotnull() & (
            FunctionExpression("len", col_to_explode) > lit(0)
        )
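        # Rows whose list is non-null and non-empty are `unnest`-ed (one output row
        # per element); rows with a null or empty list become a single row with a
        # null value. The two relations are unioned back together below.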
non_null_rel = rel.filter(not_null_condition).select(
*(
FunctionExpression("unnest", col_to_explode).alias(name)
if name in columns
else name
for name in original_columns
)
)
null_rel = rel.filter(~not_null_condition).select(
*(
lit(None).alias(name) if name in columns else name
for name in original_columns
)
)
return self._with_native(non_null_rel.union(null_rel))
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
index_ = [] if index is None else index
on_ = [c for c in self.columns if c not in index_] if on is None else on
if variable_name == "":
            msg = "`variable_name` cannot be an empty string for the DuckDB backend."
raise NotImplementedError(msg)
if value_name == "":
            msg = "`value_name` cannot be an empty string for the DuckDB backend."
raise NotImplementedError(msg)
unpivot_on = ", ".join(str(col(name)) for name in on_)
rel = self.native # noqa: F841
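        # `rel` is referenced by name inside the SQL string below: `duckdb.sql`
        # resolves it from the enclosing Python scope via a replacement scan, which
        # is why the assignment is kept despite being otherwise unused (hence the noqa).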
# Replace with Python API once
# https://github.com/duckdb/duckdb/discussions/16980 is addressed.
query = f"""
unpivot rel
on {unpivot_on}
into
name "{variable_name}"
value "{value_name}"
"""
return self._with_native(
duckdb.sql(query).select(*[*index_, variable_name, value_name])
)
gather_every = not_implemented.deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
tail = not_implemented.deprecated(
"`LazyFrame.tail` is deprecated and will be removed in a future version."
)
with_row_index = not_implemented()