import scvi
import mcp.types as types
from ..schema.scvi import (
# SCVI
SCVISetupModel,
SCVICreateModel,
SCVITrainModel,
SCVIGetLatentModel,
SCVIGetNormalizedModel,
SCVIDifferentialExpressionModel,
SCVISaveModel,
SCVILoadModel,
# SCANVI
SCANVISetupModel,
SCANVICreateModel,
SCANVIFromSCVIModel,
SCANVIPredictModel,
# TOTALVI
TOTALVISetupModel,
TOTALVICreateModel,
TOTALVIGetProteinForegroundProbModel,
# PEAKVI
PEAKVISetupModel,
PEAKVICreateModel,
PEAKVIDifferentialAccessibilityModel,
# Common
GetELBOModel,
GetReconstructionErrorModel,
)
# ==================== SCVI Tools ====================
def _tool(name, description, schema_model):
    """Build an MCP ``types.Tool`` whose input schema is *schema_model*'s JSON schema.

    *schema_model* is one of the schema classes imported above; its
    ``model_json_schema()`` output becomes the tool's ``inputSchema``.
    """
    return types.Tool(
        name=name,
        description=description,
        inputSchema=schema_model.model_json_schema(),
    )


# ---- SCVI ----
scvi_setup_anndata_tool = _tool(
    "scvi_setup_anndata",
    "Setup AnnData for SCVI model. This prepares the data for training.",
    SCVISetupModel,
)
scvi_create_model_tool = _tool(
    "scvi_create_model",
    "Create a new SCVI model for single-cell RNA-seq analysis.",
    SCVICreateModel,
)
scvi_train_model_tool = _tool(
    "scvi_train_model",
    "Train the SCVI model. This learns the latent representation.",
    SCVITrainModel,
)
scvi_get_latent_representation_tool = _tool(
    "scvi_get_latent_representation",
    "Extract latent representation from trained SCVI model. "
    "This is typically used for visualization and clustering.",
    SCVIGetLatentModel,
)
scvi_get_normalized_expression_tool = _tool(
    "scvi_get_normalized_expression",
    "Get normalized gene expression from SCVI model.",
    SCVIGetNormalizedModel,
)
scvi_differential_expression_tool = _tool(
    "scvi_differential_expression",
    "Perform differential expression analysis between groups using SCVI.",
    SCVIDifferentialExpressionModel,
)
scvi_save_model_tool = _tool(
    "scvi_save_model",
    "Save SCVI model to disk for later use.",
    SCVISaveModel,
)
scvi_load_model_tool = _tool(
    "scvi_load_model",
    "Load a previously saved SCVI model from disk.",
    SCVILoadModel,
)

# ---- SCANVI ----
scanvi_setup_anndata_tool = _tool(
    "scanvi_setup_anndata",
    "Setup AnnData for SCANVI model (semi-supervised cell type annotation).",
    SCANVISetupModel,
)
scanvi_create_model_tool = _tool(
    "scanvi_create_model",
    "Create SCANVI model for semi-supervised cell type annotation.",
    SCANVICreateModel,
)
scanvi_from_scvi_model_tool = _tool(
    "scanvi_from_scvi_model",
    "Create SCANVI model from a pre-trained SCVI model.",
    SCANVIFromSCVIModel,
)
scanvi_predict_tool = _tool(
    "scanvi_predict",
    "Predict cell types using trained SCANVI model.",
    SCANVIPredictModel,
)

# ---- TOTALVI ----
totalvi_setup_anndata_tool = _tool(
    "totalvi_setup_anndata",
    "Setup AnnData for TOTALVI model (CITE-seq: RNA + Protein).",
    TOTALVISetupModel,
)
totalvi_create_model_tool = _tool(
    "totalvi_create_model",
    "Create TOTALVI model for joint RNA and protein analysis.",
    TOTALVICreateModel,
)
totalvi_get_protein_foreground_prob_tool = _tool(
    "totalvi_get_protein_foreground_prob",
    "Get protein foreground probability from TOTALVI.",
    TOTALVIGetProteinForegroundProbModel,
)

# ---- PEAKVI ----
peakvi_setup_anndata_tool = _tool(
    "peakvi_setup_anndata",
    "Setup AnnData for PEAKVI model (scATAC-seq).",
    PEAKVISetupModel,
)
peakvi_create_model_tool = _tool(
    "peakvi_create_model",
    "Create PEAKVI model for scATAC-seq analysis.",
    PEAKVICreateModel,
)
peakvi_differential_accessibility_tool = _tool(
    "peakvi_differential_accessibility",
    "Perform differential accessibility analysis with PEAKVI.",
    PEAKVIDifferentialAccessibilityModel,
)

# ---- Common (operate on whichever model is available) ----
scvi_get_elbo_tool = _tool(
    "scvi_get_elbo",
    "Get ELBO (Evidence Lower Bound) from the model.",
    GetELBOModel,
)
scvi_get_reconstruction_error_tool = _tool(
    "scvi_get_reconstruction_error",
    "Get reconstruction error from the model.",
    GetReconstructionErrorModel,
)
# ==================== Tool Registry ====================
# Registry mapping each tool's MCP name to its ``types.Tool`` definition.
# Keys are read from each tool's ``name`` so they cannot drift from the names
# declared above; insertion order follows the tuple below.
scvi_tools = {
    tool.name: tool
    for tool in (
        # SCVI
        scvi_setup_anndata_tool,
        scvi_create_model_tool,
        scvi_train_model_tool,
        scvi_get_latent_representation_tool,
        scvi_get_normalized_expression_tool,
        scvi_differential_expression_tool,
        scvi_save_model_tool,
        scvi_load_model_tool,
        # SCANVI
        scanvi_setup_anndata_tool,
        scanvi_create_model_tool,
        scanvi_from_scvi_model_tool,
        scanvi_predict_tool,
        # TOTALVI
        totalvi_setup_anndata_tool,
        totalvi_create_model_tool,
        totalvi_get_protein_foreground_prob_tool,
        # PEAKVI
        peakvi_setup_anndata_tool,
        peakvi_create_model_tool,
        peakvi_differential_accessibility_tool,
        # Common
        scvi_get_elbo_tool,
        scvi_get_reconstruction_error_tool,
    )
}
# ==================== Tool Execution Function ====================
# Attributes on ``state`` that may hold a model, in the order the
# model-agnostic tools (ELBO, reconstruction error) search them.
_MODEL_ATTRS = ("scvi_model", "scanvi_model", "totalvi_model", "peakvi_model")


def _first_available_model(state):
    """Return the first non-None model found on *state*, or None.

    Attributes are checked in ``_MODEL_ATTRS`` order; a missing attribute is
    treated the same as one that is set to None.
    """
    for attr in _MODEL_ATTRS:
        model = getattr(state, attr, None)
        if model is not None:
            return model
    return None


def run_scvi_func(state, func_name: str, arguments: dict):
    """Execute the scvi-tools operation named by *func_name*.

    Parameters:
        state: Session object. Every tool works on the AnnData
            ``state.adata_dic[state.active]``, which most tools modify in
            place. Models are read from / stored on ``state.scvi_model``,
            ``state.scanvi_model``, ``state.totalvi_model`` and
            ``state.peakvi_model``; result tables go to ``state.de_results``
            and ``state.da_results``.
        func_name: Tool name; one of the keys of ``scvi_tools``.
        arguments: Raw tool arguments, validated by the schema model that
            matches *func_name* (the common ELBO / reconstruction-error tools
            do not read them).

    Returns:
        A human-readable status string. A missing prerequisite model is
        reported as a string starting with ``"Error:"`` rather than raised.

    Raises:
        ValueError: If *func_name* is not a supported tool.
        Exceptions from argument validation or from scvi-tools propagate
        unchanged.
    """
    adata = state.adata_dic[state.active]

    # ==================== SCVI Functions ====================
    if func_name == "scvi_setup_anndata":
        params = SCVISetupModel(**arguments)
        scvi.model.SCVI.setup_anndata(
            adata,
            layer=params.layer,
            batch_key=params.batch_key,
            labels_key=params.labels_key,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "SCVI AnnData setup complete. Ready to create model."

    if func_name == "scvi_create_model":
        params = SCVICreateModel(**arguments)
        state.scvi_model = scvi.model.SCVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers=params.n_layers,
            dropout_rate=params.dropout_rate,
            gene_likelihood=params.gene_likelihood,
        )
        return f"SCVI model created with {params.n_latent} latent dimensions."

    if func_name == "scvi_train_model":
        params = SCVITrainModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found. Create a model first."
        model.train(
            max_epochs=params.max_epochs,
            batch_size=params.batch_size,
            early_stopping=params.early_stopping,
            early_stopping_patience=params.early_stopping_patience,
        )
        return "SCVI model training complete."

    if func_name == "scvi_get_latent_representation":
        params = SCVIGetLatentModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        adata.obsm[params.save_key] = model.get_latent_representation(
            give_mean=params.give_mean
        )
        return f"Latent representation saved to adata.obsm['{params.save_key}']"

    if func_name == "scvi_get_normalized_expression":
        params = SCVIGetNormalizedModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        adata.layers[params.save_layer] = model.get_normalized_expression(
            library_size=params.library_size
        )
        return f"Normalized expression saved to adata.layers['{params.save_layer}']"

    if func_name == "scvi_differential_expression":
        params = SCVIDifferentialExpressionModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        de_df = model.differential_expression(
            groupby=params.groupby,
            group1=params.group1,
            group2=params.group2,
            mode=params.mode,
            delta=params.delta,
            batch_correction=params.batch_correction,
        )
        state.de_results = de_df
        return f"Differential expression complete. Found {len(de_df)} genes. Results stored in state.de_results"

    if func_name == "scvi_save_model":
        params = SCVISaveModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        model.save(params.dir_path, overwrite=params.overwrite)
        return f"Model saved to {params.dir_path}"

    if func_name == "scvi_load_model":
        params = SCVILoadModel(**arguments)
        state.scvi_model = scvi.model.SCVI.load(params.dir_path, adata=adata)
        return f"Model loaded from {params.dir_path}"

    # ==================== SCANVI Functions ====================
    if func_name == "scanvi_setup_anndata":
        params = SCANVISetupModel(**arguments)
        scvi.model.SCANVI.setup_anndata(
            adata,
            layer=params.layer,
            batch_key=params.batch_key,
            labels_key=params.labels_key,
            unlabeled_category=params.unlabeled_category,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "SCANVI AnnData setup complete."

    if func_name == "scanvi_create_model":
        params = SCANVICreateModel(**arguments)
        # With the classmethod setup_anndata API used in this module, the
        # unlabeled category is registered by SCANVI.setup_anndata
        # ("scanvi_setup_anndata" above). The SCANVI constructor has no such
        # parameter; forwarding it would only land in **model_kwargs. So
        # params.unlabeled_category is intentionally not passed here.
        state.scanvi_model = scvi.model.SCANVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers=params.n_layers,
            dropout_rate=params.dropout_rate,
            gene_likelihood=params.gene_likelihood,
        )
        return f"SCANVI model created with {params.n_latent} latent dimensions."

    if func_name == "scanvi_from_scvi_model":
        params = SCANVIFromSCVIModel(**arguments)
        scvi_model = getattr(state, "scvi_model", None)
        if scvi_model is None:
            return "Error: No SCVI model found."
        state.scanvi_model = scvi.model.SCANVI.from_scvi_model(
            scvi_model,
            unlabeled_category=params.unlabeled_category,
            adata=adata,
        )
        return "SCANVI model created from SCVI model."

    if func_name == "scanvi_predict":
        params = SCANVIPredictModel(**arguments)
        model = getattr(state, "scanvi_model", None)
        if model is None:
            return "Error: No SCANVI model found."
        predictions = model.predict(soft=params.soft)
        if params.soft:
            # soft=True returns a cells x classes probability table, which
            # cannot be assigned to a single obs column; keep it in obsm.
            adata.obsm[params.save_key] = predictions
            return f"Prediction probabilities saved to adata.obsm['{params.save_key}']"
        adata.obs[params.save_key] = predictions
        return f"Predictions saved to adata.obs['{params.save_key}']"

    # ==================== TOTALVI Functions ====================
    if func_name == "totalvi_setup_anndata":
        params = TOTALVISetupModel(**arguments)
        scvi.model.TOTALVI.setup_anndata(
            adata,
            protein_expression_obsm_key=params.protein_expression_obsm_key,
            layer=params.layer,
            batch_key=params.batch_key,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "TOTALVI AnnData setup complete."

    if func_name == "totalvi_create_model":
        params = TOTALVICreateModel(**arguments)
        state.totalvi_model = scvi.model.TOTALVI(
            adata,
            n_latent=params.n_latent,
            gene_likelihood=params.gene_likelihood,
        )
        return f"TOTALVI model created with {params.n_latent} latent dimensions."

    if func_name == "totalvi_get_protein_foreground_prob":
        params = TOTALVIGetProteinForegroundProbModel(**arguments)
        model = getattr(state, "totalvi_model", None)
        if model is None:
            return "Error: No TOTALVI model found."
        adata.obsm[params.save_key] = model.get_protein_foreground_probability()
        return f"Protein foreground probability saved to adata.obsm['{params.save_key}']"

    # ==================== PEAKVI Functions ====================
    if func_name == "peakvi_setup_anndata":
        params = PEAKVISetupModel(**arguments)
        scvi.model.PEAKVI.setup_anndata(
            adata,
            batch_key=params.batch_key,
            layer=params.layer,
        )
        return "PEAKVI AnnData setup complete."

    if func_name == "peakvi_create_model":
        params = PEAKVICreateModel(**arguments)
        state.peakvi_model = scvi.model.PEAKVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers_encoder=params.n_layers_encoder,
            n_layers_decoder=params.n_layers_decoder,
        )
        return f"PEAKVI model created with {params.n_latent} latent dimensions."

    if func_name == "peakvi_differential_accessibility":
        params = PEAKVIDifferentialAccessibilityModel(**arguments)
        model = getattr(state, "peakvi_model", None)
        if model is None:
            return "Error: No PEAKVI model found."
        state.da_results = model.differential_accessibility(
            groupby=params.groupby,
            group1=params.group1,
            group2=params.group2,
        )
        return "Differential accessibility complete. Results stored in state.da_results"

    # ==================== Common Functions ====================
    # These use the first available model in _MODEL_ATTRS order
    # (SCVI, then SCANVI, TOTALVI, PEAKVI).
    if func_name == "scvi_get_elbo":
        model = _first_available_model(state)
        if model is None:
            return "Error: No model found."
        return f"ELBO: {model.get_elbo()}"

    if func_name == "scvi_get_reconstruction_error":
        model = _first_available_model(state)
        if model is None:
            return "Error: No model found."
        return f"Reconstruction error: {model.get_reconstruction_error()}"

    raise ValueError(f"Unsupported function in 'scvi' module: {func_name}")