Skip to main content
Glama
scvi.py (17.3 kB)
import scvi
import mcp.types as types

from ..schema.scvi import (
    # SCVI
    SCVISetupModel,
    SCVICreateModel,
    SCVITrainModel,
    SCVIGetLatentModel,
    SCVIGetNormalizedModel,
    SCVIDifferentialExpressionModel,
    SCVISaveModel,
    SCVILoadModel,
    # SCANVI
    SCANVISetupModel,
    SCANVICreateModel,
    SCANVIFromSCVIModel,
    SCANVIPredictModel,
    # TOTALVI
    TOTALVISetupModel,
    TOTALVICreateModel,
    TOTALVIGetProteinForegroundProbModel,
    # PEAKVI
    PEAKVISetupModel,
    PEAKVICreateModel,
    PEAKVIDifferentialAccessibilityModel,
    # Common
    GetELBOModel,
    GetReconstructionErrorModel,
)


def _make_tool(tool_name, summary, schema_cls):
    """Build an MCP ``types.Tool`` whose input schema comes from *schema_cls*.

    *schema_cls* is one of the classes imported from ``..schema.scvi``; its
    ``model_json_schema()`` output (presumably a pydantic model's JSON schema —
    confirm in the schema module) becomes the tool's ``inputSchema``.
    """
    return types.Tool(
        name=tool_name,
        description=summary,
        inputSchema=schema_cls.model_json_schema(),
    )


# ==================== SCVI Tools ====================
scvi_setup_anndata_tool = _make_tool(
    "scvi_setup_anndata",
    "Setup AnnData for SCVI model. This prepares the data for training.",
    SCVISetupModel,
)

scvi_create_model_tool = _make_tool(
    "scvi_create_model",
    "Create a new SCVI model for single-cell RNA-seq analysis.",
    SCVICreateModel,
)

scvi_train_model_tool = _make_tool(
    "scvi_train_model",
    "Train the SCVI model. This learns the latent representation.",
    SCVITrainModel,
)

scvi_get_latent_representation_tool = _make_tool(
    "scvi_get_latent_representation",
    "Extract latent representation from trained SCVI model. "
    "This is typically used for visualization and clustering.",
    SCVIGetLatentModel,
)

scvi_get_normalized_expression_tool = _make_tool(
    "scvi_get_normalized_expression",
    "Get normalized gene expression from SCVI model.",
    SCVIGetNormalizedModel,
)

scvi_differential_expression_tool = _make_tool(
    "scvi_differential_expression",
    "Perform differential expression analysis between groups using SCVI.",
    SCVIDifferentialExpressionModel,
)

scvi_save_model_tool = _make_tool(
    "scvi_save_model",
    "Save SCVI model to disk for later use.",
    SCVISaveModel,
)

scvi_load_model_tool = _make_tool(
    "scvi_load_model",
    "Load a previously saved SCVI model from disk.",
    SCVILoadModel,
)

# ==================== SCANVI Tools ====================
scanvi_setup_anndata_tool = _make_tool(
    "scanvi_setup_anndata",
    "Setup AnnData for SCANVI model (semi-supervised cell type annotation).",
    SCANVISetupModel,
)

scanvi_create_model_tool = _make_tool(
    "scanvi_create_model",
    "Create SCANVI model for semi-supervised cell type annotation.",
    SCANVICreateModel,
)

scanvi_from_scvi_model_tool = _make_tool(
    "scanvi_from_scvi_model",
    "Create SCANVI model from a pre-trained SCVI model.",
    SCANVIFromSCVIModel,
)

scanvi_predict_tool = _make_tool(
    "scanvi_predict",
    "Predict cell types using trained SCANVI model.",
    SCANVIPredictModel,
)

# ==================== TOTALVI Tools ====================
totalvi_setup_anndata_tool = _make_tool(
    "totalvi_setup_anndata",
    "Setup AnnData for TOTALVI model (CITE-seq: RNA + Protein).",
    TOTALVISetupModel,
)
totalvi_create_model_tool = types.Tool(
    name="totalvi_create_model",
    description="Create TOTALVI model for joint RNA and protein analysis.",
    inputSchema=TOTALVICreateModel.model_json_schema(),
)

totalvi_get_protein_foreground_prob_tool = types.Tool(
    name="totalvi_get_protein_foreground_prob",
    description="Get protein foreground probability from TOTALVI.",
    inputSchema=TOTALVIGetProteinForegroundProbModel.model_json_schema(),
)

# ==================== PEAKVI Tools ====================
peakvi_setup_anndata_tool = types.Tool(
    name="peakvi_setup_anndata",
    description="Setup AnnData for PEAKVI model (scATAC-seq).",
    inputSchema=PEAKVISetupModel.model_json_schema(),
)

peakvi_create_model_tool = types.Tool(
    name="peakvi_create_model",
    description="Create PEAKVI model for scATAC-seq analysis.",
    inputSchema=PEAKVICreateModel.model_json_schema(),
)

peakvi_differential_accessibility_tool = types.Tool(
    name="peakvi_differential_accessibility",
    description="Perform differential accessibility analysis with PEAKVI.",
    inputSchema=PEAKVIDifferentialAccessibilityModel.model_json_schema(),
)

# ==================== Common Tools ====================
scvi_get_elbo_tool = types.Tool(
    name="scvi_get_elbo",
    description="Get ELBO (Evidence Lower Bound) from the model.",
    inputSchema=GetELBOModel.model_json_schema(),
)

scvi_get_reconstruction_error_tool = types.Tool(
    name="scvi_get_reconstruction_error",
    description="Get reconstruction error from the model.",
    inputSchema=GetReconstructionErrorModel.model_json_schema(),
)

# ==================== Tool Registry ====================
# Tool name -> MCP tool definition. ``run_scvi_func`` dispatches on these
# same names.
scvi_tools = {
    # SCVI
    "scvi_setup_anndata": scvi_setup_anndata_tool,
    "scvi_create_model": scvi_create_model_tool,
    "scvi_train_model": scvi_train_model_tool,
    "scvi_get_latent_representation": scvi_get_latent_representation_tool,
    "scvi_get_normalized_expression": scvi_get_normalized_expression_tool,
    "scvi_differential_expression": scvi_differential_expression_tool,
    "scvi_save_model": scvi_save_model_tool,
    "scvi_load_model": scvi_load_model_tool,
    # SCANVI
    "scanvi_setup_anndata": scanvi_setup_anndata_tool,
    "scanvi_create_model": scanvi_create_model_tool,
    "scanvi_from_scvi_model": scanvi_from_scvi_model_tool,
    "scanvi_predict": scanvi_predict_tool,
    # TOTALVI
    "totalvi_setup_anndata": totalvi_setup_anndata_tool,
    "totalvi_create_model": totalvi_create_model_tool,
    "totalvi_get_protein_foreground_prob": totalvi_get_protein_foreground_prob_tool,
    # PEAKVI
    "peakvi_setup_anndata": peakvi_setup_anndata_tool,
    "peakvi_create_model": peakvi_create_model_tool,
    "peakvi_differential_accessibility": peakvi_differential_accessibility_tool,
    # Common
    "scvi_get_elbo": scvi_get_elbo_tool,
    "scvi_get_reconstruction_error": scvi_get_reconstruction_error_tool,
}


# ==================== Tool Execution Function ====================

# Attributes on ``state`` that may hold a model, in the order the
# model-agnostic tools (ELBO / reconstruction error) search them.
_MODEL_ATTRS = ("scvi_model", "scanvi_model", "totalvi_model", "peakvi_model")


def _first_available_model(state):
    """Return the first non-None model on *state* in ``_MODEL_ATTRS`` order, or None."""
    for attr in _MODEL_ATTRS:
        model = getattr(state, attr, None)
        if model is not None:
            return model
    return None


def run_scvi_func(state, func_name: str, arguments: dict):
    """
    Execute scvi-tools functions.

    Parameters:
        state: Model state. The AnnData operated on is
            ``state.adata_dic[state.active]``. Created/loaded models are stored
            back on it as ``scvi_model``, ``scanvi_model``, ``totalvi_model``
            and ``peakvi_model``; DE/DA result tables as ``de_results`` and
            ``da_results``.
        func_name: Name of the tool to execute (one of the ``scvi_tools`` keys).
        arguments: Tool arguments; parsed by the schema class matching the tool.

    Returns:
        A human-readable status string. A missing prerequisite model is
        reported as a string starting with "Error:" rather than raised.

    Raises:
        ValueError: If ``func_name`` is not a supported tool. Errors raised by
            the schema classes or by scvi-tools itself propagate unchanged.
    """
    adata = state.adata_dic[state.active]

    # ==================== SCVI Functions ====================
    if func_name == "scvi_setup_anndata":
        params = SCVISetupModel(**arguments)
        scvi.model.SCVI.setup_anndata(
            adata,
            layer=params.layer,
            batch_key=params.batch_key,
            labels_key=params.labels_key,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "SCVI AnnData setup complete. Ready to create model."

    elif func_name == "scvi_create_model":
        params = SCVICreateModel(**arguments)
        state.scvi_model = scvi.model.SCVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers=params.n_layers,
            dropout_rate=params.dropout_rate,
            gene_likelihood=params.gene_likelihood,
        )
        return f"SCVI model created with {params.n_latent} latent dimensions."

    elif func_name == "scvi_train_model":
        params = SCVITrainModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found. Create a model first."
        model.train(
            max_epochs=params.max_epochs,
            batch_size=params.batch_size,
            early_stopping=params.early_stopping,
            early_stopping_patience=params.early_stopping_patience,
        )
        return "SCVI model training complete."

    elif func_name == "scvi_get_latent_representation":
        params = SCVIGetLatentModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        adata.obsm[params.save_key] = model.get_latent_representation(
            give_mean=params.give_mean
        )
        return f"Latent representation saved to adata.obsm['{params.save_key}']"

    elif func_name == "scvi_get_normalized_expression":
        params = SCVIGetNormalizedModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        adata.layers[params.save_layer] = model.get_normalized_expression(
            library_size=params.library_size
        )
        return f"Normalized expression saved to adata.layers['{params.save_layer}']"

    elif func_name == "scvi_differential_expression":
        params = SCVIDifferentialExpressionModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        de_df = model.differential_expression(
            groupby=params.groupby,
            group1=params.group1,
            group2=params.group2,
            mode=params.mode,
            delta=params.delta,
            batch_correction=params.batch_correction,
        )
        # Store results
        state.de_results = de_df
        return f"Differential expression complete. Found {len(de_df)} genes. Results stored in state.de_results"

    elif func_name == "scvi_save_model":
        params = SCVISaveModel(**arguments)
        model = getattr(state, "scvi_model", None)
        if model is None:
            return "Error: No SCVI model found."
        model.save(params.dir_path, overwrite=params.overwrite)
        return f"Model saved to {params.dir_path}"

    elif func_name == "scvi_load_model":
        params = SCVILoadModel(**arguments)
        state.scvi_model = scvi.model.SCVI.load(params.dir_path, adata=adata)
        return f"Model loaded from {params.dir_path}"

    # ==================== SCANVI Functions ====================
    elif func_name == "scanvi_setup_anndata":
        params = SCANVISetupModel(**arguments)
        scvi.model.SCANVI.setup_anndata(
            adata,
            layer=params.layer,
            batch_key=params.batch_key,
            labels_key=params.labels_key,
            unlabeled_category=params.unlabeled_category,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "SCANVI AnnData setup complete."

    elif func_name == "scanvi_create_model":
        params = SCANVICreateModel(**arguments)
        # The unlabeled category is registered by SCANVI.setup_anndata (see the
        # "scanvi_setup_anndata" branch). In the setup_anndata-based
        # scvi-tools API the SCANVI constructor has no ``unlabeled_category``
        # argument; passing it raised TypeError. params.unlabeled_category is
        # therefore not forwarded here.
        state.scanvi_model = scvi.model.SCANVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers=params.n_layers,
            dropout_rate=params.dropout_rate,
            gene_likelihood=params.gene_likelihood,
        )
        return f"SCANVI model created with {params.n_latent} latent dimensions."

    elif func_name == "scanvi_from_scvi_model":
        params = SCANVIFromSCVIModel(**arguments)
        scvi_model = getattr(state, "scvi_model", None)
        if scvi_model is None:
            return "Error: No SCVI model found."
        state.scanvi_model = scvi.model.SCANVI.from_scvi_model(
            scvi_model,
            unlabeled_category=params.unlabeled_category,
            adata=adata,
        )
        return "SCANVI model created from SCVI model."

    elif func_name == "scanvi_predict":
        params = SCANVIPredictModel(**arguments)
        model = getattr(state, "scanvi_model", None)
        if model is None:
            return "Error: No SCANVI model found."
        predictions = model.predict(soft=params.soft)
        if params.soft:
            # soft=True yields per-label probabilities (one column per label),
            # which cannot be assigned to a single obs column; keep the
            # cell x label table in obsm instead.
            adata.obsm[params.save_key] = predictions
            return f"Prediction probabilities saved to adata.obsm['{params.save_key}']"
        adata.obs[params.save_key] = predictions
        return f"Predictions saved to adata.obs['{params.save_key}']"

    # ==================== TOTALVI Functions ====================
    elif func_name == "totalvi_setup_anndata":
        params = TOTALVISetupModel(**arguments)
        scvi.model.TOTALVI.setup_anndata(
            adata,
            protein_expression_obsm_key=params.protein_expression_obsm_key,
            layer=params.layer,
            batch_key=params.batch_key,
            categorical_covariate_keys=params.categorical_covariate_keys,
            continuous_covariate_keys=params.continuous_covariate_keys,
        )
        return "TOTALVI AnnData setup complete."

    elif func_name == "totalvi_create_model":
        params = TOTALVICreateModel(**arguments)
        state.totalvi_model = scvi.model.TOTALVI(
            adata,
            n_latent=params.n_latent,
            gene_likelihood=params.gene_likelihood,
        )
        return f"TOTALVI model created with {params.n_latent} latent dimensions."

    elif func_name == "totalvi_get_protein_foreground_prob":
        params = TOTALVIGetProteinForegroundProbModel(**arguments)
        model = getattr(state, "totalvi_model", None)
        if model is None:
            return "Error: No TOTALVI model found."
        adata.obsm[params.save_key] = model.get_protein_foreground_probability()
        return f"Protein foreground probability saved to adata.obsm['{params.save_key}']"

    # ==================== PEAKVI Functions ====================
    elif func_name == "peakvi_setup_anndata":
        params = PEAKVISetupModel(**arguments)
        scvi.model.PEAKVI.setup_anndata(
            adata,
            batch_key=params.batch_key,
            layer=params.layer,
        )
        return "PEAKVI AnnData setup complete."

    elif func_name == "peakvi_create_model":
        params = PEAKVICreateModel(**arguments)
        state.peakvi_model = scvi.model.PEAKVI(
            adata,
            n_hidden=params.n_hidden,
            n_latent=params.n_latent,
            n_layers_encoder=params.n_layers_encoder,
            n_layers_decoder=params.n_layers_decoder,
        )
        return f"PEAKVI model created with {params.n_latent} latent dimensions."

    elif func_name == "peakvi_differential_accessibility":
        params = PEAKVIDifferentialAccessibilityModel(**arguments)
        model = getattr(state, "peakvi_model", None)
        if model is None:
            return "Error: No PEAKVI model found."
        state.da_results = model.differential_accessibility(
            groupby=params.groupby,
            group1=params.group1,
            group2=params.group2,
        )
        return "Differential accessibility complete. Results stored in state.da_results"

    # ==================== Common Functions ====================
    elif func_name == "scvi_get_elbo":
        # Uses whichever model is available (see _MODEL_ATTRS for the order).
        model = _first_available_model(state)
        if model is None:
            return "Error: No model found."
        elbo = model.get_elbo()
        return f"ELBO: {elbo}"

    elif func_name == "scvi_get_reconstruction_error":
        # Uses whichever model is available (see _MODEL_ATTRS for the order).
        model = _first_available_model(state)
        if model is None:
            return "Error: No model found."
        recon_error = model.get_reconstruction_error()
        return f"Reconstruction error: {recon_error}"

    else:
        raise ValueError(f"Unsupported function in 'scvi' module: {func_name}")

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/hyennnnnnn/scvi-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.