from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, cast
from ..config import registry
from ..model import Model
from ..types import ArrayXd, ListXd
# Type of the individual items held by the innermost sequences.
ItemT = TypeVar("ItemT")
# Input of the wrapper model: a sequence of sequences of items.
InT = Sequence[Sequence[ItemT]]
# Output of the wrapper model.
OutT = ListXd
# Input of the wrapped layer: one flat sequence of items.
InnerInT = Sequence[ItemT]
# Output of the wrapped layer.
InnerOutT = ArrayXd
@registry.layers("with_flatten.v1")
def with_flatten(layer: Model[InnerInT[ItemT], InnerOutT]) -> Model[InT[ItemT], OutT]:
    """Wrap `layer` so that it can be applied to nested sequences.

    layer: The model to wrap; it becomes the only sublayer of the result.
    RETURNS: A model named after `layer` that uses `forward` as its forward
        function and `init` as its initializer.
    """
    name = f"with_flatten({layer.name})"
    return Model(name, forward, layers=[layer], init=init)
def forward(
    model: Model[InT, OutT], Xnest: InT, is_train: bool
) -> Tuple[OutT, Callable]:
    """Run the wrapped layer over the flattened input and re-nest its output.

    model: The wrapper model; its first layer is the wrapped layer.
    Xnest: Nested input, one inner sequence per outer item.
    is_train: Passed through unchanged to the wrapped layer.
    RETURNS: The wrapped layer's output split per outer item along axis 0,
        and a callback that turns nested output gradients into nested input
        gradients.
    """
    layer: Model[InnerInT, InnerOutT] = model.layers[0]
    Xflat = _flatten(Xnest)
    Yflat, backprop_layer = layer(Xflat, is_train)
    # Get the split points. We want n-1 splits for n items.
    arr = layer.ops.asarray1i([len(x) for x in Xnest[:-1]])
    splits = arr.cumsum()
    Ynest = layer.ops.xp.split(Yflat, splits, axis=0)

    def backprop(dYnest: OutT) -> InT:
        dYflat = model.ops.flatten(dYnest)  # type: ignore[arg-type, var-annotated]
        # type ignore necessary for older versions of Mypy/Pydantic
        dXflat = backprop_layer(dYflat)
        # The split points count items along the first axis, so the gradient
        # is cut along axis 0, mirroring the split of Yflat above. Cutting
        # along the last axis would only agree with that for 1d gradients.
        dXnest = layer.ops.xp.split(dXflat, splits, axis=0)
        return dXnest

    return Ynest, backprop
def _flatten(nested: InT) -> InnerInT:
    """Concatenate the inner sequences of `nested` into one flat list.

    nested: A sequence of sequences of items.
    RETURNS: A new list holding every item of every inner sequence, in order.
    """
    flat: List = []
    for item in nested:
        flat.extend(item)
    # Cast to the flat alias so the cast agrees with the return annotation.
    return cast(InnerInT, flat)
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the wrapped layer with flattened sample data.

    model: The wrapper model; its first layer is the one that is initialized.
    X: Optional nested sample input; flattened before being passed on.
    Y: Optional sample output, one array per outer item; joined along axis 0
        before being passed on.
    """
    layer = model.layers[0]
    Xflat = _flatten(X) if X is not None else None
    # forward() splits the wrapped layer's output along axis 0, so the sample
    # pieces are joined along axis 0 as well. hstack only does that for 1d
    # arrays; for higher ranks it joins along axis 1.
    Yflat = layer.ops.xp.concatenate(Y, axis=0) if Y is not None else None
    layer.initialize(Xflat, Yflat)