from __future__ import annotations
import operator
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
Mapping,
Sequence,
cast,
overload,
)
import polars as pl
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
from narwhals._utils import Implementation, requires
from narwhals.dependencies import is_numpy_array_2d
from narwhals.dtypes import DType
if TYPE_CHECKING:
from datetime import timezone
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.typing import FrameT
from narwhals._utils import Version, _FullContext
from narwhals.schema import Schema
from narwhals.typing import Into1DArray, TimeUnit, _2DArray
class PolarsNamespace:
all: Method[PolarsExpr]
col: Method[PolarsExpr]
exclude: Method[PolarsExpr]
all_horizontal: Method[PolarsExpr]
any_horizontal: Method[PolarsExpr]
sum_horizontal: Method[PolarsExpr]
min_horizontal: Method[PolarsExpr]
max_horizontal: Method[PolarsExpr]
# NOTE: `pyright` accepts, `mypy` doesn't highlight the issue
# error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]"
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var]
def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
self._backend_version = backend_version
self._implementation = Implementation.POLARS
self._version = version
def __getattr__(self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
return self._expr(
getattr(pl, attr)(*pos, **kwds),
version=self._version,
backend_version=self._backend_version,
)
return func
@property
def _dataframe(self) -> type[PolarsDataFrame]:
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame
@property
def _lazyframe(self) -> type[PolarsLazyFrame]:
from narwhals._polars.dataframe import PolarsLazyFrame
return PolarsLazyFrame
@property
def _expr(self) -> type[PolarsExpr]:
return PolarsExpr
@property
def _series(self) -> type[PolarsSeries]:
return PolarsSeries
@overload
def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
@overload
def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
@overload
def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
def from_native(
self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
elif self._series._is_native(data):
return self._series.from_native(data, context=self)
elif self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
else: # pragma: no cover
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...
@overload
def from_numpy(
self,
data: _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> PolarsDataFrame: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
) -> PolarsDataFrame | PolarsSeries:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self) # pragma: no cover
@requires.backend_version(
(1, 0, 0), "Please use `col` for columns selection instead."
)
def nth(self, *indices: int) -> PolarsExpr:
return self._expr(
pl.nth(*indices), version=self._version, backend_version=self._backend_version
)
def len(self) -> PolarsExpr:
if self._backend_version < (0, 20, 5):
return self._expr(
pl.count().alias("len"),
version=self._version,
backend_version=self._backend_version,
)
return self._expr(
pl.len(), version=self._version, backend_version=self._backend_version
)
def concat(
self,
items: Iterable[FrameT],
*,
how: Literal["vertical", "horizontal", "diagonal"],
) -> PolarsDataFrame | PolarsLazyFrame:
result = pl.concat((item.native for item in items), how=how)
if isinstance(result, pl.DataFrame):
return self._dataframe(
result, backend_version=self._backend_version, version=self._version
)
return self._lazyframe.from_native(result, context=self)
def lit(self, value: Any, dtype: DType | type[DType] | None) -> PolarsExpr:
if dtype is not None:
return self._expr(
pl.lit(
value,
dtype=narwhals_to_native_dtype(
dtype, self._version, self._backend_version
),
),
version=self._version,
backend_version=self._backend_version,
)
return self._expr(
pl.lit(value), version=self._version, backend_version=self._backend_version
)
def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
if self._backend_version < (0, 20, 8):
return self._expr(
pl.sum_horizontal(e._native_expr for e in exprs)
/ pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs),
version=self._version,
backend_version=self._backend_version,
)
return self._expr(
pl.mean_horizontal(e._native_expr for e in exprs),
version=self._version,
backend_version=self._backend_version,
)
def concat_str(
self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
) -> PolarsExpr:
pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs]
if self._backend_version < (0, 20, 6):
null_mask = [expr.is_null() for expr in pl_exprs]
sep = pl.lit(separator)
if not ignore_nulls:
null_mask_result = pl.any_horizontal(*null_mask)
output_expr = pl.reduce(
lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value]
pl_exprs,
)
result = pl.when(~null_mask_result).then(output_expr)
else:
init_value, *values = [
pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
for expr, nm in zip(pl_exprs, null_mask)
]
separators = [
pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
]
result = pl.fold( # type: ignore[assignment]
acc=init_value,
function=operator.add,
exprs=[s + v for s, v in zip(separators, values)],
)
return self._expr(
result, version=self._version, backend_version=self._backend_version
)
return self._expr(
pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls),
version=self._version,
backend_version=self._backend_version,
)
# NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
# 1. Others have lots of private stuff for code reuse
# i. None of that is useful here
# 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
@property
def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
return cast(
"CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
PolarsSelectorNamespace(self),
)
class PolarsSelectorNamespace:
def __init__(self, context: _FullContext, /) -> None:
self._implementation = context._implementation
self._backend_version = context._backend_version
self._version = context._version
def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
native_dtypes = [
narwhals_to_native_dtype(
dtype, self._version, self._backend_version
).__class__
if isinstance(dtype, type) and issubclass(dtype, DType)
else narwhals_to_native_dtype(dtype, self._version, self._backend_version)
for dtype in dtypes
]
return PolarsExpr(
pl.selectors.by_dtype(native_dtypes),
version=self._version,
backend_version=self._backend_version,
)
def matches(self, pattern: str) -> PolarsExpr:
return PolarsExpr(
pl.selectors.matches(pattern=pattern),
version=self._version,
backend_version=self._backend_version,
)
def numeric(self) -> PolarsExpr:
return PolarsExpr(
pl.selectors.numeric(),
version=self._version,
backend_version=self._backend_version,
)
def boolean(self) -> PolarsExpr:
return PolarsExpr(
pl.selectors.boolean(),
version=self._version,
backend_version=self._backend_version,
)
def string(self) -> PolarsExpr:
return PolarsExpr(
pl.selectors.string(),
version=self._version,
backend_version=self._backend_version,
)
def categorical(self) -> PolarsExpr:
return PolarsExpr(
pl.selectors.categorical(),
version=self._version,
backend_version=self._backend_version,
)
def all(self) -> PolarsExpr:
return PolarsExpr(
pl.selectors.all(),
version=self._version,
backend_version=self._backend_version,
)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> PolarsExpr:
return PolarsExpr(
pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone), # type: ignore[arg-type]
version=self._version,
backend_version=self._backend_version,
)