from __future__ import annotations
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Iterator,
Mapping,
TypeVar,
cast,
overload,
)
import polars as pl
from narwhals._utils import Version, _DeferredIterable, isinstance_or_issubclass
from narwhals.exceptions import (
ColumnNotFoundError,
ComputeError,
DuplicateError,
InvalidOperationError,
NarwhalsError,
ShapeError,
)
if TYPE_CHECKING:
from typing_extensions import TypeIs
from narwhals._utils import _StoresNative
from narwhals.dtypes import DType
# Unconstrained type variable: `extract_native` uses it for objects that are
# handed back unchanged.
T = TypeVar("T")
# Type variable bounded to the native Polars objects (eager/lazy frames, series,
# expressions); it parametrises `_StoresNative` in the signatures below.
NativeT = TypeVar(
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)
@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    """Unwrap a Narwhals Polars wrapper, leaving every other object untouched.

    Arguments:
        obj: Any object. Objects recognised by `_is_compliant_polars` are
            replaced by their `.native` attribute.

    Returns:
        `obj.native` for a recognised wrapper, otherwise `obj` itself.
    """
    if _is_compliant_polars(obj):
        return obj.native
    return obj
def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    """Tell whether `obj` is one of the Narwhals wrappers around a Polars object.

    Arguments:
        obj: Object to inspect.

    Returns:
        True when `obj` is a `PolarsDataFrame`, `PolarsLazyFrame`,
        `PolarsSeries` or `PolarsExpr` instance.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import with the wrapper modules - confirm before hoisting.
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    wrapper_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    return isinstance(obj, wrapper_types)
def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    """Apply `extract_native` to positional and keyword arguments.

    Arguments:
        args: Positional arguments to unwrap (positional-only parameter).
        kwds: Keyword arguments to unwrap (positional-only parameter).

    Returns:
        A pair `(args, kwds)`: a generator that unwraps each positional
        argument as it is consumed, and a freshly built dict holding the
        unwrapped keyword arguments.
    """
    # Positional arguments stay lazy: nothing is unwrapped until iteration.
    native_args = (extract_native(positional) for positional in args)
    native_kwds: dict[str, Any] = {}
    for key, value in kwds.items():
        native_kwds[key] = extract_native(value)
    return native_args, native_kwds
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
) -> DType:
    """Translate a Polars dtype into the corresponding Narwhals dtype.

    Results are memoised through `lru_cache`, so every argument must be hashable.

    Arguments:
        dtype: Polars dtype to translate.
        version: Narwhals API version; `version.dtypes` supplies the target dtypes.
        backend_version: Polars version tuple, used where attribute names
            differ between Polars releases.

    Returns:
        The matching Narwhals dtype, or `Unknown()` when no rule applies.
    """
    dtypes = version.dtypes

    def convert(inner: pl.DataType) -> DType:
        # Nested dtypes are translated with the same version settings.
        return native_to_narwhals_dtype(inner, version, backend_version)

    # Not available for Polars pre 1.8.0, so probe for them before comparing.
    version_gated = ("Int128", "UInt128")
    # Parameter-free dtypes that carry the same name on both sides. Names are
    # resolved one at a time, in this order, only when their turn comes.
    leading_names = (
        "Float64",
        "Float32",
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
    )
    for type_name in leading_names:
        if type_name in version_gated and not hasattr(pl, type_name):
            continue
        if dtype == getattr(pl, type_name):
            return getattr(dtypes, type_name)()

    if isinstance_or_issubclass(dtype, pl.Enum):
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        # `_DeferredIterable` receives a callable that yields the categories.
        load_categories = (
            (lambda: cast("list[str]", dtype.categories))
            if backend_version < (0, 20, 4)
            else dtype.categories.to_list
        )
        return dtypes.Enum(_DeferredIterable(load_categories))

    if dtype == pl.Date:
        return dtypes.Date()

    if isinstance_or_issubclass(dtype, pl.Datetime):
        # The bare class carries no parameters to forward.
        if dtype is pl.Datetime:
            return dtypes.Datetime()
        return dtypes.Datetime(dtype.time_unit, dtype.time_zone)

    if isinstance_or_issubclass(dtype, pl.Duration):
        if dtype is pl.Duration:
            return dtypes.Duration()
        return dtypes.Duration(dtype.time_unit)

    if isinstance_or_issubclass(dtype, pl.Struct):
        converted_fields = []
        for field_name, field_dtype in dtype:
            converted_fields.append(dtypes.Field(field_name, convert(field_dtype)))
        return dtypes.Struct(converted_fields)

    if isinstance_or_issubclass(dtype, pl.List):
        return dtypes.List(convert(dtype.inner))

    if isinstance_or_issubclass(dtype, pl.Array):
        # The attribute read depends on the Polars version.
        if backend_version < (0, 20, 30):
            outer_shape = dtype.width
        else:
            outer_shape = dtype.size
        return dtypes.Array(convert(dtype.inner), outer_shape)

    # Remaining parameter-free dtypes, checked after the nested ones.
    for type_name in ("Decimal", "Time", "Binary"):
        if dtype == getattr(pl, type_name):
            return getattr(dtypes, type_name)()

    return dtypes.Unknown()
def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...]
) -> pl.DataType:
    """Translate a Narwhals dtype into the corresponding Polars dtype.

    Arguments:
        dtype: Narwhals dtype (class or instance) to translate.
        version: Narwhals API version; `version.dtypes` supplies the dtypes
            that `dtype` is compared against.
        backend_version: Polars version tuple, used to pick the keyword
            accepted by `pl.Array`.

    Returns:
        The matching Polars dtype, or `pl.Unknown()` when no rule applies
        (including 128-bit integers when the installed Polars lacks them).

    Raises:
        NotImplementedError: For `Enum` under `Version.V1`, and for `Decimal`.
        ValueError: For the bare `Enum` class, which carries no categories.
    """
    dtypes = version.dtypes
    if dtype == dtypes.Float64:
        return pl.Float64()
    if dtype == dtypes.Float32:
        return pl.Float32()
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.Int64:
        return pl.Int64()
    if dtype == dtypes.Int32:
        return pl.Int32()
    if dtype == dtypes.Int16:
        return pl.Int16()
    if dtype == dtypes.Int8:
        return pl.Int8()
    if hasattr(pl, "UInt128") and dtype == dtypes.UInt128:
        # Mirrors the Int128 branch: only map when the installed Polars exposes
        # the dtype, otherwise fall through as before.
        return pl.UInt128()
    if dtype == dtypes.UInt64:
        return pl.UInt64()
    if dtype == dtypes.UInt32:
        return pl.UInt32()
    if dtype == dtypes.UInt16:
        return pl.UInt16()
    if dtype == dtypes.UInt8:
        return pl.UInt8()
    if dtype == dtypes.String:
        return pl.String()
    if dtype == dtypes.Boolean:
        return pl.Boolean()
    if dtype == dtypes.Object:  # pragma: no cover
        return pl.Object()
    if dtype == dtypes.Categorical:
        return pl.Categorical()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        # Only an instance carries categories; the bare class cannot be converted.
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if dtype == dtypes.Date:
        return pl.Date()
    if dtype == dtypes.Time:
        return pl.Time()
    if dtype == dtypes.Binary:
        return pl.Binary()
    if dtype == dtypes.Decimal:
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(
                field.name,
                narwhals_to_native_dtype(field.dtype, version, backend_version),
            )
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # The keyword accepted by `pl.Array` depends on the Polars version.
        kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
        return pl.Array(
            narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
        )
    return pl.Unknown()  # pragma: no cover
def catch_polars_exception(
    exception: Exception, backend_version: tuple[int, ...]
) -> NarwhalsError | Exception:
    """Map a Polars exception onto its Narwhals counterpart.

    Arguments:
        exception: Exception to translate.
        backend_version: Polars version tuple; selects how other Polars errors
            are detected.

    Returns:
        A new Narwhals exception built from `str(exception)` when `exception`
        is recognised as coming from Polars, otherwise `exception` unchanged.
    """
    # Ordered (Polars exception name, Narwhals class) pairs. Each Polars class
    # is looked up only when its turn comes.
    translations = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_error in translations:
        if isinstance(exception, getattr(pl.exceptions, polars_name)):
            return narwhals_error(str(exception))

    if backend_version >= (1,):
        # Old versions of Polars didn't have PolarsError.
        from_polars = isinstance(exception, pl.exceptions.PolarsError)
    else:  # pragma: no cover
        # Last attempt, for old Polars versions.
        from_polars = "polars.exceptions" in str(type(exception))
    if from_polars:
        return NarwhalsError(str(exception))  # pragma: no cover

    # Just return exception as-is.
    return exception