from __future__ import annotations
import operator
from functools import reduce
from typing import TYPE_CHECKING, Iterable, Sequence, cast
import dask.dataframe as dd
import pandas as pd
from narwhals._compliant import CompliantThen, CompliantWhen, LazyNamespace
from narwhals._compliant.namespace import DepthTrackingNamespace
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
align_series_full_broadcast,
narwhals_to_native_dtype,
validate_comparand,
)
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._utils import Version
from narwhals.dtypes import DType
from narwhals.typing import ConcatMethod, NonNestedLiteral
class DaskNamespace(
    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
    """Namespace of expression/frame constructors for the Dask backend.

    Every expression-building method returns a `DaskExpr` wrapping a closure
    that is only evaluated when it is later called with a `DaskLazyFrame`.
    The `*_horizontal` methods and `concat_str` compute their depth with
    `max(...)` over `exprs`, so they need at least one expression.
    """

    _implementation: Implementation = Implementation.DASK

    @property
    def selectors(self) -> DaskSelectorNamespace:
        """Selector namespace built from this namespace."""
        return DaskSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DaskExpr]:
        """Expression class used by this namespace."""
        return DaskExpr

    @property
    def _lazyframe(self) -> type[DaskLazyFrame]:
        """Lazy frame class used by this namespace."""
        return DaskLazyFrame

    def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
        """Store the backend version tuple and the narwhals API version."""
        self._backend_version = backend_version
        self._version = version

    def lit(self, value: NonNestedLiteral, dtype: DType | type[DType] | None) -> DaskExpr:
        """Return an expression producing a single-value series named "literal".

        Arguments:
            value: The literal value to wrap.
            dtype: Optional narwhals dtype; when given it is converted to a
                native dtype and applied to the underlying pandas series.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            if dtype is not None:
                native_dtype = narwhals_to_native_dtype(dtype, self._version)
                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
            else:
                native_pd_series = pd.Series([value], name="literal")
            # Use the same number of partitions as the frame being evaluated.
            npartitions = df._native_frame.npartitions
            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
            return [dask_series[0].to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def len(self) -> DaskExpr:
        """Return an expression named "len" built from the first column's size."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # We don't allow dataframes with 0 columns, so `[0]` is safe.
            return [df._native_frame[df.columns[0]].size.to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def all_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Fold the series produced by `exprs` together with `&`, row by row."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(operator.and_, series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def any_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Fold the series produced by `exprs` together with `|`, row by row."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(operator.or_, series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Concatenate the series column-wise and sum along `axis=1`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).sum(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def concat(
        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
    ) -> DaskLazyFrame:
        """Concatenate lazy frames row-wise.

        Arguments:
            items: Frames to concatenate; any iterable, including a generator.
            how: "vertical" requires identical column names in identical
                order and concatenates with `join="inner"`; "diagonal"
                concatenates with `join="outer"`.

        Raises:
            AssertionError: If `items` yields no frames.
            TypeError: If `how="vertical"` and column names don't match.
            NotImplementedError: For any other value of `how`.
        """
        # Materialise before the emptiness check: a generator is always truthy,
        # so testing `items` directly would let an empty one through and fail
        # later with an IndexError on `dfs[0]`.
        dfs = [item._native_frame for item in items]
        if not dfs:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        cols_0 = dfs[0].columns
        if how == "vertical":
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                # Length is compared first so the element-wise comparison only
                # runs on indexes of equal length.
                if not (
                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
                ):
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f"   - dataframe 0: {cols_0.to_list()}\n"
                        f"   - dataframe {i}: {cols_current.to_list()}\n"
                    )
                    raise TypeError(msg)
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="inner"),
                backend_version=self._backend_version,
                version=self._version,
            )
        if how == "diagonal":
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="outer"),
                backend_version=self._backend_version,
                version=self._version,
            )
        raise NotImplementedError

    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise mean over the non-null values of the evaluated series.

        Nulls are filled with 0 in the numerator and excluded from the
        denominator, which is the per-row sum of `1 - isna()`.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
            non_na = align_series_full_broadcast(
                df, *(1 - s.isna() for s in expr_results)
            )
            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Concatenate the series column-wise and take the minimum along `axis=1`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).min(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Concatenate the series column-wise and take the maximum along `axis=1`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).max(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def when(self, predicate: DaskExpr) -> DaskWhen:
        """Start a when/then chain from `predicate`."""
        return DaskWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
    ) -> DaskExpr:
        """Concatenate the evaluated series as strings, joined by `separator`.

        Arguments:
            exprs: Expressions whose results are cast to `str` and joined.
            separator: String inserted between consecutive values.
            ignore_nulls: If False, a row with a null in any input gets a null
                result. If True, null values are replaced by "" and the
                separator that would follow a null value is dropped.

        Output names are taken from the first expression.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            # Broadcast once; both the string view and the null masks are
            # derived from the same aligned series.
            aligned = list(align_series_full_broadcast(df, *expr_results))
            series = [s.astype(str) for s in aligned]
            null_mask = [s.isna() for s in aligned]

            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).where(
                    ~null_mask_result, None
                )
            else:
                init_value, *values = [
                    s.where(~nm, "") for s, nm in zip(series, null_mask)
                ]
                # One separator series per gap: empty where the value before
                # the gap is null, `separator` otherwise.
                separators = (
                    nm.map({True: "", False: separator}, meta=str)
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    operator.add, (s + v for s, v in zip(separators, values)), init_value
                )

            return [result]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=getattr(
                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
            backend_version=self._backend_version,
            version=self._version,
        )
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
    """Dask evaluation of a when/then/otherwise conditional.

    NOTE(review): `_condition`, `_then_value` and `_otherwise_value` are not
    assigned in this class; presumably `CompliantWhen` provides them - confirm.
    """

    @property
    def _then(self) -> type[DaskThen]:
        """Class used for the "then" stage of the chain."""
        return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        """Evaluate the conditional against `df` and return a one-item list.

        The "then" value is kept where the condition holds. Elsewhere the
        "otherwise" value is used if one was given; if it is None,
        `where` is called without a replacement.
        """
        mask = self._condition(df)[0]

        # "then" may be an expression (evaluate it) or a plain value.
        then_raw = self._then_value
        then_value = then_raw(df)[0] if isinstance(then_raw, DaskExpr) else then_raw
        (then_series,) = align_series_full_broadcast(df, then_value)
        validate_comparand(mask, then_series)

        fallback = self._otherwise_value
        if fallback is None:
            return [then_series.where(mask)]
        if not isinstance(fallback, DaskExpr):
            # Non-expression fallback is handed to `where` as-is.
            return [then_series.where(mask, fallback)]  # pyright: ignore[reportArgumentType]

        (otherwise_series,) = align_series_full_broadcast(df, fallback(df)[0])
        validate_comparand(mask, otherwise_series)
        return [then_series.where(mask, otherwise_series)]  # pyright: ignore[reportArgumentType]
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr):
    """The "then" stage of a Dask when/then chain.

    Adds no members of its own; everything comes from `CompliantThen`
    and `DaskExpr`.
    """