from pydantic import Field
from typing import Optional, List, Union, Dict, Any
from .base import JSONParsingModel
# ==================== Setup ====================
class SetupAnndataModel(JSONParsingModel):
    """Setup AnnData for scvi-tools models."""

    # Shared AnnData-setup options. Every field is optional and defaults to
    # None. SCVISetupModel and SCANVISetupModel subclass this model and so
    # inherit all five fields.
    layer: Optional[str] = Field(
        None,
        description="Key in adata.layers for count data. If None, uses adata.X",
    )
    batch_key: Optional[str] = Field(
        None,
        description="Key in adata.obs for batch information",
    )
    labels_key: Optional[str] = Field(
        None,
        description="Key in adata.obs for cell type labels (for SCANVI)",
    )
    categorical_covariate_keys: Optional[List[str]] = Field(
        None,
        description="Keys in adata.obs for categorical covariates",
    )
    continuous_covariate_keys: Optional[List[str]] = Field(
        None,
        description="Keys in adata.obs for continuous covariates",
    )
# ==================== SCVI ====================
class SCVISetupModel(SetupAnndataModel):
    """Setup AnnData for SCVI model."""

    # Declares no fields of its own; everything comes from SetupAnndataModel.
class SCVICreateModel(JSONParsingModel):
    """Create SCVI model."""

    # Hyperparameters for constructing an SCVI model; every field has a
    # default, so an empty payload is valid.
    #
    # Numeric bounds are enforced at parse time so that values which cannot
    # describe a network (non-positive widths, a dropout probability outside
    # [0, 1]) are rejected here instead of surfacing later during model
    # construction.
    n_hidden: int = Field(
        default=128,
        gt=0,
        description="Number of nodes per hidden layer",
    )
    n_latent: int = Field(
        default=10,
        gt=0,
        description="Dimensionality of the latent space",
    )
    # NOTE(review): n_layers is intentionally left unbounded; confirm how the
    # underlying model treats values below 1 before adding a lower bound.
    n_layers: int = Field(
        default=1,
        description="Number of hidden layers",
    )
    dropout_rate: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Dropout rate for neural networks",
    )
    # The accepted likelihood names are only listed in the description; the
    # field itself accepts any string.
    gene_likelihood: str = Field(
        default="zinb",
        description="Gene likelihood: 'zinb', 'nb', 'poisson'",
    )
class SCVITrainModel(JSONParsingModel):
    """Train SCVI model."""

    # Training options; every field has a default.
    #
    # NOTE(review): max_epochs is intentionally left unbounded; confirm
    # whether the trainer gives special meaning to non-positive values before
    # constraining it.
    max_epochs: int = Field(
        default=400,
        description="Maximum number of training epochs",
    )
    # A minibatch must contain at least one sample.
    batch_size: int = Field(
        default=128,
        gt=0,
        description="Minibatch size for training",
    )
    early_stopping: bool = Field(
        default=True,
        description="Enable early stopping",
    )
    # Patience counts epochs without improvement, so it cannot be negative.
    early_stopping_patience: int = Field(
        default=45,
        ge=0,
        description="Number of epochs with no improvement after which training stops",
    )
class SCVIGetLatentModel(JSONParsingModel):
    """Get latent representation from SCVI."""

    # Both fields have defaults: give_mean is True and save_key is "X_scVI".
    give_mean: bool = Field(True, description="Return mean of latent distribution")
    save_key: str = Field(
        "X_scVI",
        description="Key in adata.obsm to save latent representation",
    )
class SCVIGetNormalizedModel(JSONParsingModel):
    """Get normalized expression from SCVI."""

    # Options for exporting normalized expression; both fields have defaults.
    #
    # library_size is a scaling target, so zero or negative values are
    # rejected at parse time.
    library_size: float = Field(
        default=1e4,
        gt=0,
        description="Library size to use for normalization",
    )
    save_layer: str = Field(
        default="scvi_normalized",
        description="Layer in adata.layers to save normalized expression",
    )
class SCVIDifferentialExpressionModel(JSONParsingModel):
    """Perform differential expression analysis with SCVI."""

    # groupby is the only required field; all others carry defaults.
    groupby: str = Field(description="Key in adata.obs for grouping")

    # group1 is a list of names while group2 is a single name; both may be
    # omitted (None).
    group1: Optional[List[str]] = Field(
        None,
        description="First group of cells. If None, all groups are compared",
    )
    group2: Optional[str] = Field(
        None,
        description="Second group to compare. If None, compare to rest",
    )

    mode: str = Field("change", description="Mode: 'vanilla' or 'change'")
    delta: float = Field(
        0.25,
        description="Specific case of region inducing differential expression",
    )
    batch_correction: bool = Field(False, description="Correct for batch effects")
class SCVISaveModel(JSONParsingModel):
    """Save SCVI model."""

    # dir_path is required; overwrite defaults to False.
    dir_path: str = Field(description="Path to directory where to save the model")
    overwrite: bool = Field(False, description="Overwrite existing directory")
class SCVILoadModel(JSONParsingModel):
    """Load SCVI model."""

    # Single required field: the directory to load from.
    dir_path: str = Field(description="Path to saved model directory")
# ==================== SCANVI ====================
class SCANVISetupModel(SetupAnndataModel):
    """Setup AnnData for SCANVI model."""

    # Extends the inherited SetupAnndataModel fields with one required field.
    unlabeled_category: str = Field(
        description="Value used for unlabeled cells in labels_key"
    )
class SCANVICreateModel(JSONParsingModel):
    """Create SCANVI model."""

    # Hyperparameters for constructing a SCANVI model. unlabeled_category is
    # the only required field; the rest have defaults.
    #
    # Numeric bounds are enforced at parse time so that values which cannot
    # describe a network (non-positive widths, a dropout probability outside
    # [0, 1]) are rejected here instead of surfacing later during model
    # construction.
    unlabeled_category: str = Field(
        description="Value used for unlabeled cells"
    )
    n_hidden: int = Field(
        default=128,
        gt=0,
        description="Number of nodes per hidden layer",
    )
    n_latent: int = Field(
        default=10,
        gt=0,
        description="Dimensionality of the latent space",
    )
    # NOTE(review): n_layers is intentionally left unbounded; confirm how the
    # underlying model treats values below 1 before adding a lower bound.
    n_layers: int = Field(
        default=1,
        description="Number of hidden layers",
    )
    dropout_rate: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Dropout rate",
    )
    # The accepted likelihood names are only listed in the description; the
    # field itself accepts any string.
    gene_likelihood: str = Field(
        default="zinb",
        description="Gene likelihood: 'zinb', 'nb', 'poisson'",
    )
class SCANVIFromSCVIModel(JSONParsingModel):
    """Create SCANVI from pretrained SCVI."""

    # Single required field.
    unlabeled_category: str = Field(description="Value used for unlabeled cells")
class SCANVIPredictModel(JSONParsingModel):
    """Predict cell types with SCANVI."""

    # Both fields have defaults: save_key is "scanvi_predictions" and soft is
    # False.
    save_key: str = Field(
        "scanvi_predictions",
        description="Key in adata.obs to save predictions",
    )
    soft: bool = Field(
        False,
        description="Return probabilities instead of predicted labels",
    )
# ==================== TOTALVI ====================
class TOTALVISetupModel(JSONParsingModel):
    """Setup AnnData for TOTALVI model."""

    # Derives directly from JSONParsingModel rather than SetupAnndataModel, so
    # it declares its own layer/batch/covariate fields and has no labels_key.
    # protein_expression_obsm_key is the only required field.
    protein_expression_obsm_key: str = Field(
        description="Key in adata.obsm for protein expression"
    )
    layer: Optional[str] = Field(None, description="Layer for RNA counts")
    batch_key: Optional[str] = Field(None, description="Batch key in adata.obs")
    categorical_covariate_keys: Optional[List[str]] = Field(
        None,
        description="Categorical covariate keys",
    )
    continuous_covariate_keys: Optional[List[str]] = Field(
        None,
        description="Continuous covariate keys",
    )
class TOTALVICreateModel(JSONParsingModel):
    """Create TOTALVI model."""

    # Hyperparameters for constructing a TOTALVI model; both fields have
    # defaults.
    #
    # A latent space needs at least one dimension, so non-positive values are
    # rejected at parse time.
    n_latent: int = Field(
        default=20,
        gt=0,
        description="Dimensionality of latent space",
    )
    # The accepted likelihood names are only listed in the description; the
    # field itself accepts any string.
    gene_likelihood: str = Field(
        default="nb",
        description="Gene likelihood: 'nb', 'zinb'",
    )
class TOTALVIGetProteinForegroundProbModel(JSONParsingModel):
    """Get protein foreground probability."""

    # Single field, defaulting to "totalvi_protein_fg_prob".
    save_key: str = Field(
        "totalvi_protein_fg_prob",
        description="Key to save in adata.obsm",
    )
# ==================== PEAKVI ====================
class PEAKVISetupModel(JSONParsingModel):
    """Setup AnnData for PEAKVI model."""

    # Both fields are optional and default to None.
    batch_key: Optional[str] = Field(None, description="Batch key in adata.obs")
    layer: Optional[str] = Field(
        None,
        description="Layer for counts. If None, use adata.X",
    )
class PEAKVICreateModel(JSONParsingModel):
    """Create PEAKVI model."""

    # Hyperparameters for constructing a PEAKVI model; every field has a
    # default.
    #
    # Layer width and latent dimensionality must be positive, so non-positive
    # values are rejected at parse time.
    n_hidden: int = Field(
        default=128,
        gt=0,
        description="Number of nodes per hidden layer",
    )
    n_latent: int = Field(
        default=10,
        gt=0,
        description="Dimensionality of latent space",
    )
    # NOTE(review): the two layer counts are intentionally left unbounded;
    # confirm how the underlying model treats values below 1 before adding a
    # lower bound.
    n_layers_encoder: int = Field(
        default=2,
        description="Number of layers for encoder",
    )
    n_layers_decoder: int = Field(
        default=2,
        description="Number of layers for decoder",
    )
class PEAKVIDifferentialAccessibilityModel(JSONParsingModel):
    """Differential accessibility analysis with PEAKVI."""

    # groupby is required. group1 (a list of names) and group2 (a single
    # name) are optional and default to None.
    groupby: str = Field(description="Key in adata.obs for grouping")
    group1: Optional[List[str]] = Field(None, description="First group")
    group2: Optional[str] = Field(None, description="Second group")
# ==================== Common Operations ====================
class GetELBOModel(JSONParsingModel):
    """Get ELBO (Evidence Lower Bound)."""

    # Declares no fields of its own.
class GetReconstructionErrorModel(JSONParsingModel):
    """Get reconstruction error."""

    # Declares no fields of its own.