# parameter.py • 13.6 kB
import inspect
from collections.abc import Iterable
from copy import deepcopy
from typing import (
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
import attrs
from attrs import define, field
import cyclopts._env_var
import cyclopts.utils
from cyclopts._convert import ITERABLE_TYPES
from cyclopts.annotations import NoneType, is_annotated, is_nonetype, is_union, resolve_annotated, resolve_optional
from cyclopts.field_info import signature_parameters
from cyclopts.group import Group
from cyclopts.token import Token
from cyclopts.utils import (
default_name_transform,
frozen,
optional_to_tuple_converter,
record_init,
to_tuple_converter,
)
# Iterable-of-bool annotations. ``Parameter.get_negatives`` treats a type found
# in this set the same way as plain ``bool`` when picking negative-flag prefixes.
ITERATIVE_BOOL_IMPLICIT_VALUE = frozenset(
    {
        Iterable[bool],
        Sequence[bool],
        List[bool],
        list[bool],
        Tuple[bool, ...],
        tuple[bool, ...],
    }
)

# Generic type variable used for signatures that hand back what they were given.
T = TypeVar("T")

# Types (or generic origins) for which ``Parameter.get_negatives`` may generate
# negative flags; any other type yields no negatives.
_NEGATIVE_FLAG_TYPES = frozenset([bool, None, NoneType, *ITERABLE_TYPES, *ITERATIVE_BOOL_IMPLICIT_VALUE])
def _not_hyphen_validator(instance, attribute, values):
for value in values:
if value is not None and value.startswith("-"):
raise ValueError(f'{attribute.alias} value must NOT start with "-".')
def _negative_converter(default: tuple[str, ...]):
def converter(value) -> tuple[str, ...]:
if value is None:
return default
else:
return to_tuple_converter(value)
return converter
# TODO: Breaking change; all fields after ``name`` should be ``kw_only=True``.
@record_init("_provided_args")
@frozen
class Parameter:
    """Cyclopts configuration for individual function parameters with :obj:`~typing.Annotated`.

    Example usage:

    .. code-block:: python

        from cyclopts import App, Parameter
        from typing import Annotated

        app = App()


        @app.default
        def main(foo: Annotated[int, Parameter(name="bar")]):
            print(foo)


        app()

    .. code-block:: console

        $ my-script 100
        100
        $ my-script --bar 100
        100
    """

    # All attribute docstrings have been moved to ``docs/api.rst`` for greater control with attrs.

    # This can ONLY ever be a Tuple[str, ...]
    # Usually starts with "--" or "-"
    name: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=lambda x: cast(tuple[str, ...], to_tuple_converter(x)),
    )

    converter: Optional[Callable[[Any, Sequence[Token]], Any]] = field(default=None)

    # This can ONLY ever be a Tuple[Callable, ...]
    validator: Union[None, Callable[[Any, Any], Any], Iterable[Callable[[Any, Any], Any]]] = field(
        default=(),
        converter=lambda x: cast(tuple[Callable[[Any, Any], Any], ...], to_tuple_converter(x)),
    )

    # This can ONLY ever be a Tuple[str, ...]
    alias: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=lambda x: cast(tuple[str, ...], to_tuple_converter(x)),
    )

    # This can ONLY ever be ``None`` or ``Tuple[str, ...]``
    negative: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=optional_to_tuple_converter,
        kw_only=True,
    )

    # This can ONLY ever be a Tuple[Union[Group, str], ...]
    group: Union[None, Group, str, Iterable[Union[Group, str]]] = field(
        default=None,
        converter=to_tuple_converter,
        kw_only=True,
        hash=False,
    )

    parse: bool = field(
        default=None,
        converter=attrs.converters.default_if_none(True),
        kw_only=True,
    )

    _show: Optional[bool] = field(
        default=None,
        alias="show",
        kw_only=True,
    )

    show_default: Union[None, bool, Callable[[Any], Any]] = field(
        default=None,
        kw_only=True,
    )

    show_choices: bool = field(
        default=None,
        converter=attrs.converters.default_if_none(True),
        kw_only=True,
    )

    help: Optional[str] = field(default=None, kw_only=True)

    show_env_var: bool = field(
        default=None,
        converter=attrs.converters.default_if_none(True),
        kw_only=True,
    )

    # This can ONLY ever be a Tuple[str, ...]
    env_var: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=lambda x: cast(tuple[str, ...], to_tuple_converter(x)),
        kw_only=True,
    )

    env_var_split: Callable = field(
        default=cyclopts._env_var.env_var_split,
        kw_only=True,
    )

    # This can ONLY ever be a Tuple[str, ...]
    negative_bool: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=_negative_converter(("no-",)),
        validator=_not_hyphen_validator,
        kw_only=True,
    )

    # This can ONLY ever be a Tuple[str, ...]
    negative_iterable: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=_negative_converter(("empty-",)),
        validator=_not_hyphen_validator,
        kw_only=True,
    )

    # This can ONLY ever be a Tuple[str, ...]
    negative_none: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=_negative_converter(()),
        validator=_not_hyphen_validator,
        kw_only=True,
    )

    required: Optional[bool] = field(
        default=None,
        kw_only=True,
    )

    allow_leading_hyphen: bool = field(
        default=False,
        kw_only=True,
    )

    _name_transform: Optional[Callable[[str], str]] = field(
        alias="name_transform",
        default=None,
        kw_only=True,
    )

    accepts_keys: Optional[bool] = field(
        default=None,
        kw_only=True,
    )

    consume_multiple: bool = field(
        default=None,
        converter=attrs.converters.default_if_none(False),
        kw_only=True,
    )

    json_dict: Optional[bool] = field(default=None, kw_only=True)

    json_list: Optional[bool] = field(default=None, kw_only=True)

    # Populated by the ``record_init`` decorator applied to this class.
    # Read by ``__repr__`` and ``combine`` as the set of init aliases to consider.
    _provided_args: tuple[str, ...] = field(factory=tuple, init=False, eq=False)

    @property
    def show(self) -> bool:
        """Explicit ``show`` value if one was given, otherwise the value of ``parse``."""
        return self._show if self._show is not None else self.parse

    @property
    def name_transform(self):
        """Configured name transform, falling back to ``default_name_transform``."""
        return self._name_transform if self._name_transform else default_name_transform

    def get_negatives(self, type_) -> tuple[str, ...]:
        """Compute the negative flag names for this parameter given a type hint.

        Only long names (starting with ``--``) get a negative counterpart. With
        the default prefixes, ``--foo`` of type ``bool`` yields ``--no-foo`` and
        ``--a.b`` yields ``--a.no-b`` (the prefix is applied to the last dotted
        component).

        Parameters
        ----------
        type_
            Type hint of the parameter. ``Annotated`` is resolved, and unions are
            expanded member by member.

        Returns
        -------
        tuple[str, ...]
            Negative flags, without duplicates, in discovery order. Empty if the
            type (and its generic origin) is not in ``_NEGATIVE_FLAG_TYPES``.
        """
        type_ = resolve_annotated(type_)
        if is_union(type_):
            out = []
            for x in get_args(type_):
                out.extend(self.get_negatives(x))
            # Several union members can produce the very same flag (e.g. a
            # user-supplied ``negative`` applies to both ``bool`` and ``None``);
            # drop repeats while preserving order.
            return tuple(dict.fromkeys(out))

        origin = get_origin(type_)
        # ``None`` is itself a member of ``_NEGATIVE_FLAG_TYPES``, so a missing
        # origin must be rejected explicitly rather than via the membership test.
        if type_ not in _NEGATIVE_FLAG_TYPES and (not origin or origin not in _NEGATIVE_FLAG_TYPES):
            return ()

        out, user_negatives = [], []
        if self.negative:
            # Fully-specified flags (leading "-") are used verbatim; bare names are
            # expanded relative to each long name below.
            for negative in self.negative:
                (out if negative.startswith("-") else user_negatives).append(negative)
            if not user_negatives:
                return tuple(out)

        # The prefix set depends only on the type, not on the individual name.
        if type_ is bool or type_ in ITERATIVE_BOOL_IMPLICIT_VALUE:
            negative_prefixes = self.negative_bool
        elif is_nonetype(type_) or type_ is None:
            negative_prefixes = self.negative_none
        else:
            negative_prefixes = self.negative_iterable
        assert isinstance(negative_prefixes, tuple)

        assert isinstance(self.name, tuple)
        for name in self.name:
            if not name.startswith("--"):  # Only provide negation for option-like long flags.
                continue
            name = name[2:]
            name_components = name.split(".")

            name_prefix = ".".join(name_components[:-1])
            if name_prefix:
                name_prefix += "."

            if self.negative is None:
                for negative_prefix in negative_prefixes:
                    out.append(f"--{name_prefix}{negative_prefix}{name_components[-1]}")
            else:
                for negative in user_negatives:
                    out.append(f"--{name_prefix}{negative}")
        return tuple(out)

    def __repr__(self):
        """Only shows non-default values."""
        content = ", ".join(
            [
                f"{a.alias}={getattr(self, a.name)!r}"
                for a in self.__attrs_attrs__  # pyright: ignore[reportAttributeAccessIssue]
                if a.alias in self._provided_args
            ]
        )
        return f"{type(self).__name__}({content})"

    @classmethod
    def combine(cls, *parameters: Optional["Parameter"]) -> "Parameter":
        """Returns a new Parameter with combined values of all provided ``parameters``.

        Parameters
        ----------
        `*parameters`: Optional[Parameter]
            Parameters whose explicitly-provided attributes are merged.
            Ordered from least-to-highest attribute priority.
            ``None`` entries are ignored.

        Returns
        -------
        Parameter
            ``EMPTY_PARAMETER`` if nothing remains after dropping ``None``; the
            sole remaining parameter itself (not a copy) if there is exactly one;
            otherwise a newly constructed instance.
        """
        filtered = [x for x in parameters if x is not None]

        # In the common case of 0/1 parameters to combine, we can avoid
        # instantiating a new Parameter object.
        if len(filtered) == 1:
            return filtered[0]
        elif not filtered:
            return EMPTY_PARAMETER

        kwargs = {}
        for parameter in filtered:
            # Later parameters overwrite earlier ones, but only for the arguments
            # they were explicitly constructed with.
            for alias in parameter._provided_args:
                kwargs[alias] = getattr(parameter, _parameter_alias_to_name[alias])

        return cls(**kwargs)

    @classmethod
    def default(cls) -> "Parameter":
        """Create a Parameter with all Cyclopts-default values.

        This is different than just :class:`Parameter` because the default
        values will be recorded and override all upstream parameter values.
        """
        return cls(
            **{a.alias: a.default for a in cls.__attrs_attrs__ if a.init}  # pyright: ignore[reportAttributeAccessIssue]
        )

    @classmethod
    def from_annotation(cls, type_: Any, *default_parameters: Optional["Parameter"]) -> tuple[Any, "Parameter"]:
        """Resolve the immediate Parameter from a type hint.

        Parameters
        ----------
        type_: Any
            Type hint, or ``inspect.Parameter.empty`` when there is no annotation.
        `*default_parameters`: Optional[Parameter]
            Lower-priority parameters combined ahead of those found on ``type_``.

        Returns
        -------
        tuple[Any, Parameter]
            The hint as returned by ``get_parameters`` (or ``type_`` unchanged when
            it is ``inspect.Parameter.empty``) and the combined parameter.
        """
        if type_ is inspect.Parameter.empty:
            # ``combine`` already returns ``EMPTY_PARAMETER`` when given nothing.
            return type_, cls.combine(*default_parameters)

        type_, parameters = get_parameters(type_)
        return type_, cls.combine(*default_parameters, *parameters)

    def __call__(self, obj: T) -> T:
        """Decorator interface for annotating a function/class with a :class:`Parameter`.

        Most commonly used for directly configuring a class:

        .. code-block:: python

            @Parameter(...)
            class Foo: ...

        The parameter is appended to ``obj.__cyclopts__.parameters``; the config
        is created on first use. ``obj`` itself is returned.
        """
        if not hasattr(obj, "__cyclopts__"):
            obj.__cyclopts__ = CycloptsConfig(obj=obj)  # pyright: ignore[reportAttributeAccessIssue]
        elif obj.__cyclopts__.obj is not obj:  # pyright: ignore[reportAttributeAccessIssue]
            # Create a copy so that children class Parameter decorators don't impact the parent.
            inherited = deepcopy(obj.__cyclopts__)  # pyright: ignore[reportAttributeAccessIssue]
            # ``deepcopy`` returns classes/functions unchanged, so the copy would
            # still name the parent as owner; claim it for ``obj`` so subsequent
            # decorators on ``obj`` reuse this config instead of copying again.
            inherited.obj = obj
            obj.__cyclopts__ = inherited  # pyright: ignore[reportAttributeAccessIssue]
        obj.__cyclopts__.parameters.append(self)  # pyright: ignore[reportAttributeAccessIssue]
        return obj
# Maps each ``Parameter`` init keyword (the attrs alias, e.g. ``"show"``) to the
# underlying attribute name (e.g. ``"_show"``). Only init-able attributes are
# included. ``Parameter.combine`` uses it to read values back off an instance.
_parameter_alias_to_name = {
    p.alias: p.name
    for p in Parameter.__attrs_attrs__  # pyright: ignore[reportAttributeAccessIssue]
    if p.init
}

# Shared instance constructed with no arguments; returned by ``Parameter.combine``
# when there is nothing to combine.
EMPTY_PARAMETER = Parameter()
def validate_command(f: Callable):
    """Check that a function follows Cyclopts's rules.

    Functions whose ``__module__`` starts with ``"cyclopts"`` are skipped.

    Raises
    ------
    ValueError
        Function has naming or parameter/signature inconsistencies.
    """
    # python3.9 functools.partial does not have "__module__" attribute.
    # TODO: simplify to (f.__module__ or "") once cp3.9 is dropped.
    module_name = getattr(f, "__module__", "") or ""
    if module_name.startswith("cyclopts"):  # Speed optimization.
        return

    for info in signature_parameters(f).values():
        # Speed optimization: un-annotated parameters have nothing to validate,
        # and checking for an annotation is significantly faster than
        # instantiating a cyclopts.Parameter object.
        if is_annotated(info.annotation):
            _, cparam = Parameter.from_annotation(info.annotation)
            if cparam.parse or info.kind is info.KEYWORD_ONLY:
                continue
            raise ValueError("Parameter.parse=False must be used with a KEYWORD_ONLY function parameter.")
def get_parameters(hint: T) -> tuple[T, list[Parameter]]:
    """Collect the cyclopts.Parameter objects attached at the root of a hint.

    Parameters stored on the hint's ``__cyclopts__`` attribute come first,
    followed by any ``Parameter`` instances in its ``Annotated`` metadata.

    Returns
    -------
    hint
        Annotation hint with :obj:`Annotated` and :obj:`Optional` resolved.
    list[Parameter]
        List of parameters discovered.
    """
    resolved = resolve_optional(hint)
    found = []

    config = getattr(resolved, "__cyclopts__", None)
    if config:
        found += config.parameters

    if not is_annotated(resolved):
        return resolved, found

    # First argument is the wrapped type; the rest is Annotated metadata.
    base, *metadata = get_args(resolved)
    found += [item for item in metadata if isinstance(item, Parameter)]
    return base, found
@define
class CycloptsConfig:
    """Container stored on a decorated object's ``__cyclopts__`` attribute.

    Intended for storing additional data to a ``__cyclopts__`` attribute via decoration.
    """

    # Object the config was originally created for.
    obj: Any = None
    # Parameters attached via decoration, in the order they were appended.
    # Starts as an empty list and cannot be passed to ``__init__``.
    parameters: list[Parameter] = field(factory=list, init=False)