analyze_enrichment
Perform gene set enrichment analysis to identify overrepresented pathways or spatial patterns. Supports multiple methods and customizable parameters for species and gene sets.
Instructions
Perform gene set enrichment analysis.
Args:
data_id: Dataset ID
params: Required - species must be specified. See EnrichmentParameters for methods and gene_set_database options.
Input Schema
| Name | Required | Description | Default |
|---|---|---|---|
| data_id | Yes | Dataset ID | |
| params | Yes | EnrichmentParameters; `species` must be specified | |
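The input fields can be illustrated with a few example `params` payloads. This is a sketch based on the EnrichmentParameters schema shown in the Implementation Reference; the gene names and signature name are illustrative, not part of the schema:

```python
# Minimal call: only `species` is required; `method` defaults to spatial_enrichmap
# and `gene_set_database` defaults to GO_Biological_Process.
minimal_params = {"species": "mouse"}

# Standard pathway ORA against KEGG with FDR correction.
ora_params = {
    "species": "human",
    "method": "pathway_ora",
    "gene_set_database": "KEGG_Pathways",
    "pvalue_cutoff": 0.05,
    "adjust_method": "fdr",
}

# Spatially aware scoring of a custom signature (signature name and
# genes are hypothetical examples).
spatial_params = {
    "species": "mouse",
    "method": "spatial_enrichmap",
    "gene_sets": {"hypoxia": ["Hif1a", "Vegfa", "Slc2a1"]},
    "n_neighbors": 6,
    "smoothing": True,
}
```

Passing custom `gene_sets` skips database loading entirely, which is the recommended way to score a specific signature with `spatial_enrichmap`.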
Output Schema
| Name | Required | Description | Default |
|---|---|---|---|
| method | Yes | Enrichment method that produced the result | |
| n_gene_sets | Yes | Number of gene sets tested | |
| n_significant | Yes | Pathways passing the significance cutoff (0 for spatial_enrichmap, which has no p-values) | |
| top_gene_sets | Yes | Up to 10 top enriched gene sets | |
| top_depleted_sets | Yes | Top depleted gene sets (empty for ORA and spatial_enrichmap) | |
| n_successful_signatures | No | Signatures successfully scored (spatial_enrichmap only) | |
| spatial_scores_key | No | Reserved; spatial scores are written to per-signature obs columns | |
| enrichment_scores | No | Per-gene-set scores, filtered to significant pathways | |
| pvalues | No | Nominal p-values per gene set, filtered to significant pathways | |
| adjusted_pvalues | No | FDR-adjusted p-values per gene set, filtered to significant pathways | |
| gene_set_statistics | No | Per-gene-set detail (ES/NES or odds ratio, overlap, lead genes) | |
| spatial_metrics | No | | |
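A typical way to consume this output is to keep only pathways that pass the FDR cutoff and rank them by score. The sketch below uses a hand-built dict shaped like the output schema (all values are mock illustrations; real payloads come from the tool):

```python
# Mock result shaped like the output schema above (values are illustrative).
result = {
    "method": "gsea",
    "n_gene_sets": 3,
    "n_significant": 2,
    "top_gene_sets": ["Hypoxia", "Glycolysis"],
    "top_depleted_sets": ["Oxidative_Phosphorylation"],
    "adjusted_pvalues": {"Hypoxia": 0.01, "Glycolysis": 0.04, "Apoptosis": 0.40},
    "enrichment_scores": {"Hypoxia": 0.72, "Glycolysis": 0.55, "Apoptosis": 0.20},
}

# Keep pathways passing the FDR cutoff, ranked by enrichment score.
significant = sorted(
    (p for p, q in result["adjusted_pvalues"].items() if q < 0.05),
    key=lambda p: result["enrichment_scores"][p],
    reverse=True,
)
# significant == ["Hypoxia", "Glycolysis"]
```

Note that the large dictionaries are already pre-filtered to significant pathways server-side to keep the MCP response small, so this client-side pass is mainly for re-ranking.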
Implementation Reference
- Unified entry point for analyze_enrichment. Loads gene sets (from database or custom), normalizes them, dispatches to the correct method (spatial_enrichmap, pathway_gsea, pathway_ora, pathway_ssgsea, pathway_enrichr), and returns an EnrichmentResult.
```python
async def analyze_enrichment(
    data_id: str,
    ctx: "ToolContext",
    params: "EnrichmentParameters",
) -> EnrichmentResult:
    """
    Unified entry point for gene set enrichment analysis.

    This function handles all enrichment methods with a consistent interface:
    - Gene set loading from databases
    - Method dispatch (GSEA, ORA, ssGSEA, Enrichr, spatial)
    - Error handling with clear messages

    Args:
        data_id: Dataset ID
        ctx: ToolContext for data access and logging
        params: EnrichmentParameters with method, species, database, etc.

    Returns:
        EnrichmentResult with enrichment scores and statistics

    Raises:
        ParameterError: If params is None or invalid
        ProcessingError: If gene set loading or analysis fails
    """
    # Import here to avoid circular imports
    from ..utils.adata_utils import get_highly_variable_genes

    # Validate params
    if params is None:
        raise ParameterError(
            "params parameter is required for enrichment analysis.\n"
            "You must provide EnrichmentParameters with at least 'species' specified.\n"
            "Example: params={'species': 'mouse', 'method': 'pathway_ora'}"
        )

    # Get adata
    adata = await ctx.get_adata(data_id)

    # Load gene sets
    loaded_from_database = False
    gene_sets = params.gene_sets
    if gene_sets is None and params.gene_set_database:
        try:
            gene_sets = load_gene_sets(
                database=params.gene_set_database,
                species=params.species,
                min_genes=params.min_genes,
                max_genes=params.max_genes,
                ctx=ctx,
            )
            loaded_from_database = True
        except Exception as e:
            await ctx.error(f"Gene set database loading failed: {e}")
            raise ProcessingError(
                f"Failed to load gene sets from {params.gene_set_database}: {e}\n\n"
                f"SOLUTIONS:\n"
                f"1. Check your internet connection\n"
                f"2. Verify species parameter: '{params.species}'\n"
                f"3. Try a different database (KEGG_Pathways, GO_Biological_Process)\n"
                f"4. Provide custom gene sets via 'gene_sets' parameter"
            ) from e

    # Validate gene sets
    if gene_sets is None or len(gene_sets) == 0:
        raise ProcessingError(
            "No valid gene sets available. "
            "Please provide gene sets via 'gene_sets' parameter or "
            "specify a valid 'gene_set_database'."
        )

    # Normalize gene_sets to dict format (convert list to single gene set dict)
    gene_sets_dict: dict[str, list[str]]
    if isinstance(gene_sets, list):
        gene_sets_dict = {"user_genes": gene_sets}
    else:
        gene_sets_dict = gene_sets

    if params.method == "spatial_enrichmap" and loaded_from_database:
        n_loaded_gene_sets = len(gene_sets_dict)
        gene_sets_dict = _limit_spatial_enrichmap_gene_sets(
            gene_sets_dict,
            available_genes=set(adata.var_names),
            species=params.species,
        )
        if len(gene_sets_dict) < n_loaded_gene_sets:
            await ctx.warning(
                f"Spatial EnrichMap scored the {len(gene_sets_dict)} database gene sets "
                f"with highest dataset overlap out of {n_loaded_gene_sets} loaded sets. "
                "Provide custom gene_sets to score specific signatures."
            )
        gene_sets = gene_sets_dict

    # Normalize score_keys to single string for methods that require it
    ranking_key: str | None = None
    if params.score_keys is not None:
        ranking_key = (
            params.score_keys[0]
            if isinstance(params.score_keys, list)
            else params.score_keys
        )

    # Dispatch to appropriate method
    if params.method == "spatial_enrichmap":
        result = await perform_spatial_enrichment(
            data_id=data_id,
            ctx=ctx,
            gene_sets=gene_sets,
            score_keys=params.score_keys,
            spatial_key=params.spatial_key,
            n_neighbors=params.n_neighbors,
            smoothing=params.smoothing,
            correct_spatial_covariates=params.correct_spatial_covariates,
            batch_key=params.batch_key,
            species=params.species,
            database=params.gene_set_database,
        )
    elif params.method == "pathway_gsea":
        result = perform_gsea(
            adata=adata,
            gene_sets=gene_sets_dict,
            ranking_key=ranking_key,
            permutation_num=params.n_permutations,
            min_size=params.min_genes,
            max_size=params.max_genes,
            pvalue_cutoff=params.pvalue_cutoff,
            species=params.species,
            database=params.gene_set_database,
            ctx=ctx,
            data_id=data_id,
        )
    elif params.method == "pathway_ora":
        result = perform_ora(
            adata=adata,
            gene_sets=gene_sets_dict,
            pvalue_threshold=params.pvalue_cutoff,
            min_size=params.min_genes,
            max_size=params.max_genes,
            adjust_method=params.adjust_method,
            species=params.species,
            database=params.gene_set_database,
            ctx=ctx,
            data_id=data_id,
        )
    elif params.method == "pathway_ssgsea":
        result = perform_ssgsea(
            adata=adata,
            gene_sets=gene_sets_dict,
            min_size=params.min_genes,
            max_size=params.max_genes,
            species=params.species,
            database=params.gene_set_database,
            ctx=ctx,
            data_id=data_id,
        )
    elif params.method == "pathway_enrichr":
        gene_list = get_highly_variable_genes(adata, max_genes=500)
        result = perform_enrichr(
            gene_list=gene_list,
            gene_sets=params.gene_set_database,
            organism=params.species,
            pvalue_cutoff=params.pvalue_cutoff,
            ctx=ctx,
        )
    else:
        raise ParameterError(f"Unknown enrichment method: {params.method}")

    return result
```
- perform_spatial_enrichment: Spatially-aware enrichment using EnrichMap scoring. Runs enrichment for each gene set with spatial smoothing/normalization, returns EnrichmentResult with per-signature statistics.
```python
async def perform_spatial_enrichment(
    data_id: str,
    ctx: "ToolContext",
    gene_sets: Union[list[str], dict[str, list[str]]],
    score_keys: Optional[Union[str, list[str]]] = None,
    spatial_key: str = "spatial",
    n_neighbors: int = 6,
    smoothing: bool = True,
    correct_spatial_covariates: bool = True,
    batch_key: Optional[str] = None,
    species: str = "unknown",
    database: Optional[str] = None,
) -> "EnrichmentResult":
    """Perform spatially-aware gene set enrichment analysis using EnrichMap.

    Args:
        data_id: Identifier for the spatial data in the data store.
        ctx: MCP tool context for data access and logging.
        gene_sets: Either a single gene list or a dictionary of gene sets
            where keys are signature names and values are lists of genes.
        score_keys: Names for the gene signatures if gene_sets is a list.
            Ignored if gene_sets is already a dictionary.
        spatial_key: Key in adata.obsm containing spatial coordinates.
        n_neighbors: Number of nearest spatial neighbors for smoothing.
        smoothing: Whether to perform spatial smoothing.
        correct_spatial_covariates: Whether to correct for spatial covariates using GAM.
        batch_key: Column in adata.obs for batch-wise normalization.
        species: Species for the analysis (e.g., 'mouse', 'human').
        database: Gene set database used (e.g., 'KEGG_Pathways', 'GO_Biological_Process').

    Returns:
        EnrichmentResult containing enrichment scores and statistics.
    """
    # Check if EnrichMap is available
    require("enrichmap", ctx, feature="spatial enrichment analysis")

    # Import EnrichMap
    import enrichmap as em

    # Get data using standard ctx pattern
    adata = await ctx.get_adata(data_id)

    # Validate spatial coordinates
    if spatial_key not in adata.obsm:
        raise ProcessingError(
            f"Spatial coordinates '{spatial_key}' not found in adata.obsm"
        )

    # Convert single gene list to dictionary format
    gene_sets_dict: dict[str, list[str]]
    if isinstance(gene_sets, list):
        # For a single gene list, score_keys should be a string name
        if score_keys is None:
            sig_name = "enrichmap_signature"
        elif isinstance(score_keys, str):
            sig_name = score_keys
        else:
            # If score_keys is a list, use the first element
            sig_name = score_keys[0] if score_keys else "enrichmap_signature"
        gene_sets_dict = {sig_name: gene_sets}
    else:
        gene_sets_dict = gene_sets

    # Validate gene sets with format conversion
    available_genes = set(adata.var_names)
    validated_gene_sets = {}
    for sig_name, genes in gene_sets_dict.items():
        common_genes = _match_gene_set_to_dataset(genes, available_genes, species)
        if len(common_genes) < 2:
            await ctx.warning(
                f"Signature '{sig_name}' has {len(common_genes)} genes in the dataset. Skipping."
            )
            continue
        validated_gene_sets[sig_name] = common_genes
        await ctx.info(
            f"Signature '{sig_name}': {len(common_genes)}/{len(genes)} genes found"
        )

    if not validated_gene_sets:
        raise ProcessingError(
            f"No valid gene signatures found (≥2 genes). "
            f"Dataset: {len(available_genes)} genes, requested: {len(gene_sets_dict)} signatures. "
            f"Check species (human/mouse) and gene name format."
        )

    # Run EnrichMap scoring - process each gene set individually
    failed_signatures = []
    successful_signatures = []
    for sig_name, genes in validated_gene_sets.items():
        try:
            em.tl.score(
                adata=adata,
                gene_set=genes,  # Fixed: use gene_set (correct API parameter name)
                score_key=sig_name,  # Fixed: provide explicit score_key
                spatial_key=spatial_key,
                n_neighbors=n_neighbors,
                smoothing=smoothing,
                correct_spatial_covariates=correct_spatial_covariates,
                batch_key=batch_key,
            )
            successful_signatures.append(sig_name)
        except Exception as e:
            await ctx.warning(f"EnrichMap failed for '{sig_name}': {e}")
            failed_signatures.append((sig_name, str(e)))

    # Check if any signatures were processed successfully
    if not successful_signatures:
        error_details = "; ".join(
            [f"{name}: {error}" for name, error in failed_signatures]
        )
        raise ProcessingError(
            f"All EnrichMap scoring failed. This may indicate:\n"
            f"1. EnrichMap package installation issues\n"
            f"2. Incompatible gene names or data format\n"
            f"3. Insufficient spatial information\n"
            f"Details: {error_details}"
        )

    # Update validated_gene_sets to only include successful ones
    validated_gene_sets = {
        sig: validated_gene_sets[sig] for sig in successful_signatures
    }

    if ctx and failed_signatures:
        await ctx.warning(
            f"Failed to process {len(failed_signatures)} gene sets: {[name for name, _ in failed_signatures]}"
        )

    # Collect results
    score_columns = [f"{sig}_score" for sig in validated_gene_sets]

    # Calculate summary statistics
    summary_stats = {}
    for sig_name in validated_gene_sets:
        score_col = f"{sig_name}_score"
        scores = adata.obs[score_col]
        summary_stats[sig_name] = {
            "mean": float(scores.mean()),
            "std": float(scores.std()),
            "min": float(scores.min()),
            "max": float(scores.max()),
            "median": float(scores.median()),
            "q25": float(scores.quantile(0.25)),
            "q75": float(scores.quantile(0.75)),
            "n_genes": len(validated_gene_sets[sig_name]),
        }

    # Per-run parametrized key so multiple database runs coexist
    analysis_key = _build_enrichment_key("spatial", database)

    # Store gene set membership (shared for compat + per-run for provenance)
    adata.uns["enrichment_spatial_gene_sets"] = validated_gene_sets
    per_run_gs_key = (
        f"enrichment_spatial_gene_sets_{analysis_key.removeprefix('enrichment_')}"
    )
    adata.uns[per_run_gs_key] = validated_gene_sets

    # Store metadata for scientific provenance tracking
    store_analysis_metadata(
        adata,
        analysis_name=analysis_key,
        method="spatial_enrichmap",
        parameters={
            "spatial_key": spatial_key,
            "n_neighbors": n_neighbors,
            "smoothing": smoothing,
            "correct_spatial_covariates": correct_spatial_covariates,
            "batch_key": batch_key,
        },
        results_keys={
            "obs": score_columns,
            "uns": [per_run_gs_key],
        },
        statistics={
            "n_gene_sets": len(validated_gene_sets),
            "n_successful_signatures": len(successful_signatures),
            "n_failed_signatures": len(failed_signatures),
        },
        species=species,
        database=database,
    )

    # Export results for reproducibility
    export_analysis_result(adata, data_id, analysis_key)

    # Create enrichment scores (use max score per gene set)
    enrichment_scores = {
        sig_name: float(stats["max"]) for sig_name, stats in summary_stats.items()
    }

    # Sort by enrichment score to get top gene sets
    sorted_sigs = sorted(enrichment_scores.items(), key=lambda x: x[1], reverse=True)
    top_gene_sets = [sig_name for sig_name, _ in sorted_sigs[:10]]

    # Spatial enrichment doesn't provide p-values, so return empty gene_set_statistics
    # to reduce MCP response size (no significance filtering possible)
    pvalues = None
    adjusted_pvalues = None

    return EnrichmentResult(
        method="spatial_enrichmap",
        n_gene_sets=len(validated_gene_sets),
        # No significance testing in spatial enrichment: n_significant=0
        # (successful computation != statistical significance)
        n_significant=0,
        n_successful_signatures=len(successful_signatures),
        enrichment_scores=enrichment_scores,
        pvalues=pvalues,
        adjusted_pvalues=adjusted_pvalues,
        gene_set_statistics={},  # Empty to reduce response size (no p-values available)
        spatial_scores_key=None,  # Scores are in obs columns, not obsm
        top_gene_sets=top_gene_sets,
        top_depleted_sets=[],  # Spatial enrichment doesn't produce depleted sets
    )
```
- chatspatial/tools/enrichment.py:609-867 (handler) perform_gsea: Gene Set Enrichment Analysis (GSEA) via gseapy prerank. Computes gene ranking (signal-to-noise, CV, variance, etc.), runs GSEA, filters significant pathways, stores results in adata.uns.
```python
def perform_gsea(
    adata: "ad.AnnData",
    gene_sets: dict[str, list[str]],
    ranking_key: Optional[str] = None,
    method: str = "signal_to_noise",
    permutation_num: int = 1000,
    min_size: int = 10,
    max_size: int = 500,
    pvalue_cutoff: float = 0.25,
    species: Optional[str] = None,
    database: Optional[str] = None,
    ctx: Optional["ToolContext"] = None,
    data_id: Optional[str] = None,
) -> "EnrichmentResult":
    """Perform Gene Set Enrichment Analysis (GSEA).

    Args:
        adata: Annotated data matrix.
        gene_sets: Gene sets to test (name -> gene list).
        ranking_key: Key in adata.var for pre-computed ranking. Computes if None.
        method: Method for ranking genes if ranking_key is None.
        permutation_num: Number of permutations for p-value calculation.
        min_size: Minimum gene set size.
        max_size: Maximum gene set size.
        pvalue_cutoff: FDR threshold for counting and filtering significant pathways.
        species: Species for the analysis ('mouse', 'human').
        database: Gene set database used ('KEGG_Pathways', 'GO_Biological_Process').
        ctx: MCP tool context for logging.
        data_id: Dataset identifier for result tracking.

    Returns:
        EnrichmentResult with enrichment scores and statistics.
    """
    # gseapy imported at module level (required dependency)
    ranking_method = method.strip().lower()

    # Prepare ranking
    if ranking_key and ranking_key in adata.var:
        # Use pre-computed ranking
        ranking = adata.var[ranking_key].to_dict()
    else:
        # Compute ranking from expression data
        # Use get_raw_data_source (single source of truth) for complete gene coverage
        raw_result = get_raw_data_source(adata, prefer_complete_genes=True)
        X = raw_result.X
        var_names = raw_result.var_names

        # Compute gene ranking metric
        # IMPORTANT: GSEA requires biologically meaningful ranking, not just variance
        # Reference: Subramanian et al. (2005) PNAS, GSEA-MSIGDB documentation
        if ranking_method not in {
            "signal_to_noise",
            "coefficient_of_variation",
            "variance",
            "highly_variable_rank",
            "dispersions_norm",
        }:
            raise ParameterError(
                "Unsupported GSEA ranking method: "
                f"{method}. Supported: signal_to_noise, coefficient_of_variation, "
                "variance, highly_variable_rank, dispersions_norm"
            )

        if ranking_method == "signal_to_noise":
            group_key = "condition" if "condition" in adata.obs else "group"
            groups = adata.obs[group_key].unique() if group_key in adata.obs else []
            if len(groups) == 2:
                # Binary comparison: Use Signal-to-Noise Ratio (GSEA default)
                # S2N = (μ1 - μ2) / (σ1 + σ2)
                # This captures both differential expression AND expression stability
                group1_mask = adata.obs[group_key] == groups[0]
                group2_mask = adata.obs[group_key] == groups[1]

                # Compute means
                mean1 = np.array(X[group1_mask, :].mean(axis=0)).flatten()
                mean2 = np.array(X[group2_mask, :].mean(axis=0)).flatten()

                # Compute standard deviations (sparse-compatible)
                std1 = _compute_std_sparse_compatible(X[group1_mask, :], axis=0, ddof=1)
                std2 = _compute_std_sparse_compatible(X[group2_mask, :], axis=0, ddof=1)

                # Apply minimum std threshold (GSEA standard: 0.2 * |mean|)
                # This prevents division by zero and reduces noise from low-variance genes
                min_std_factor = 0.2
                std1 = np.maximum(std1, min_std_factor * np.abs(mean1))
                std2 = np.maximum(std2, min_std_factor * np.abs(mean2))

                # Compute Signal-to-Noise Ratio
                s2n = (mean1 - mean2) / (std1 + std2)
                ranking = dict(zip(var_names, s2n, strict=True))
            else:
                # S2N requires exactly 2 groups; fall back to CV when unavailable.
                ranking = _compute_cv_ranking(X, var_names)
        elif ranking_method == "coefficient_of_variation":
            ranking = _compute_cv_ranking(X, var_names)
        elif ranking_method == "variance":
            ranking = _compute_variance_ranking(X, var_names)
        elif ranking_method == "highly_variable_rank":
            if "highly_variable_rank" not in adata.var:
                raise DataNotFoundError(
                    "Ranking method 'highly_variable_rank' requested but "
                    "adata.var['highly_variable_rank'] is missing."
                )
            ranking = adata.var["highly_variable_rank"].to_dict()
        elif ranking_method == "dispersions_norm":
            if "dispersions_norm" not in adata.var:
                raise DataNotFoundError(
                    "Ranking method 'dispersions_norm' requested but "
                    "adata.var['dispersions_norm'] is missing."
                )
            ranking = adata.var["dispersions_norm"].to_dict()
        else:  # pragma: no cover - defensive fallback, guarded by method validation
            # Defensive fallback (should be unreachable due to validation above)
            ranking = _compute_cv_ranking(X, var_names)

    # Run GSEA preranked
    try:
        # Convert ranking dict to DataFrame for gseapy
        ranking_df = pd.DataFrame.from_dict(ranking, orient="index", columns=["score"])
        ranking_df.index.name = "gene"
        ranking_df = ranking_df.sort_values("score", ascending=False)

        res = _get_gseapy().prerank(
            rnk=ranking_df,  # Pass DataFrame instead of dict
            gene_sets=gene_sets,
            processes=1,
            permutation_num=permutation_num,
            min_size=min_size,
            max_size=max_size,
            seed=42,
            verbose=False,
            no_plot=True,
            outdir=None,
        )

        # Extract results
        results_df = res.res2d

        # Prepare output - OPTIMIZED: vectorized dict + array iteration (16x faster)
        enrichment_scores = dict(zip(results_df["Term"], results_df["ES"]))
        pvalues = dict(zip(results_df["Term"], results_df["NOM p-val"]))
        adjusted_pvalues = dict(zip(results_df["Term"], results_df["FDR q-val"]))

        # Pre-extract arrays for fast iteration
        terms = results_df["Term"].values
        es_vals = results_df["ES"].values
        nes_vals = results_df["NES"].values
        pval_vals = results_df["NOM p-val"].values
        fdr_vals = results_df["FDR q-val"].values
        has_matched_size = "Matched_size" in results_df.columns
        has_lead_genes = "Lead_genes" in results_df.columns
        size_vals = (
            results_df["Matched_size"].values
            if has_matched_size
            else np.zeros(len(terms))
        )
        lead_genes_vals = (
            results_df["Lead_genes"].values if has_lead_genes else [""] * len(terms)
        )

        gene_set_statistics = {}
        for i in range(len(terms)):
            lead_genes_str = lead_genes_vals[i] if has_lead_genes else ""
            gene_set_statistics[terms[i]] = {
                "es": float(es_vals[i]),
                "nes": float(nes_vals[i]),
                "pval": float(pval_vals[i]),
                "fdr": float(fdr_vals[i]),
                "size": int(size_vals[i]) if has_matched_size else 0,
                "lead_genes": lead_genes_str.split(";")[:10] if lead_genes_str else [],
            }

        # Get top enriched and depleted
        results_df_sorted = results_df.sort_values("NES", ascending=False)
        top_enriched = (
            results_df_sorted[results_df_sorted["NES"] > 0].head(10)["Term"].tolist()
        )
        top_depleted = (
            results_df_sorted[results_df_sorted["NES"] < 0].head(10)["Term"].tolist()
        )

        # Save results to adata.uns for visualization
        # Store full results DataFrame for visualization (shared key for viz compat)
        adata.uns["gsea_results"] = results_df

        # Per-run parametrized key so multiple database runs coexist
        analysis_key = _build_enrichment_key("gsea", database)
        per_run_key = f"gsea_results_{analysis_key.removeprefix('enrichment_')}"
        adata.uns[per_run_key] = results_df

        # Store gene set membership (shared for compat + per-run for provenance)
        adata.uns["enrichment_gsea_gene_sets"] = gene_sets
        per_run_gs_key = (
            f"enrichment_gsea_gene_sets_{analysis_key.removeprefix('enrichment_')}"
        )
        adata.uns[per_run_gs_key] = gene_sets

        # Store metadata for scientific provenance tracking
        store_analysis_metadata(
            adata,
            analysis_name=analysis_key,
            method="gsea",
            parameters={
                "permutation_num": permutation_num,
                "ranking_method": ranking_method,
                "min_size": min_size,
                "max_size": max_size,
                "ranking_key": ranking_key,
            },
            results_keys={
                "uns": [per_run_key, per_run_gs_key],
            },
            statistics={
                "n_gene_sets": len(gene_sets),
                "n_significant": len(
                    results_df[results_df["FDR q-val"] < pvalue_cutoff]
                ),
            },
            species=species,
            database=database,
        )

        # Export results to CSV for reproducibility
        if data_id is not None:
            export_analysis_result(adata, data_id, analysis_key)

        # Filter all result dictionaries to only significant pathways (reduces MCP response size)
        # Uses method-based FDR threshold: GSEA = 0.25 (Subramanian et al. 2005)
        (
            filtered_statistics,
            filtered_scores,
            filtered_pvals,
            filtered_adj_pvals,
        ) = _filter_significant_statistics(
            gene_set_statistics,
            enrichment_scores,
            pvalues,
            adjusted_pvalues,
            method="gsea",
            fdr_threshold=pvalue_cutoff,
        )

        return EnrichmentResult(
            method="gsea",
            n_gene_sets=len(gene_sets),
            n_significant=len(results_df[results_df["FDR q-val"] < pvalue_cutoff]),
            enrichment_scores=filtered_scores,
            pvalues=filtered_pvals,
            adjusted_pvalues=filtered_adj_pvals,
            gene_set_statistics=filtered_statistics,
            top_gene_sets=top_enriched,
            top_depleted_sets=top_depleted,
        )
    except Exception as e:
        logger.error(f"GSEA failed: {e}")
        raise
```
- chatspatial/tools/enrichment.py:870-1133 (handler) perform_ora: Over-Representation Analysis (ORA) using Fisher's exact test. Gets DEGs or HVGs as query, tests overlap with gene sets, applies multiple testing correction, returns filtered EnrichmentResult.
```python
def perform_ora(
    adata: "ad.AnnData",
    gene_sets: dict[str, list[str]],
    gene_list: Optional[list[str]] = None,
    pvalue_threshold: float = 0.05,
    significance_threshold: Optional[float] = None,
    min_size: int = 10,
    max_size: int = 500,
    adjust_method: str = "fdr",
    species: Optional[str] = None,
    database: Optional[str] = None,
    ctx: Optional["ToolContext"] = None,
    data_id: Optional[str] = None,
) -> "EnrichmentResult":
    """Perform Over-Representation Analysis (ORA).

    Args:
        adata: Annotated data matrix.
        gene_sets: Gene sets to test (name -> gene list).
        gene_list: Genes to test. Uses DEGs from rank_genes_groups if None.
        pvalue_threshold: P-value threshold for selecting DEGs from
            rank_genes_groups (only used when gene_list is None).
        significance_threshold: Adjusted p-value threshold for counting and
            filtering significant pathways in ORA results. If None, defaults
            to 0.05 (standard Benjamini-Hochberg FDR control).
        min_size: Minimum gene set size.
        max_size: Maximum gene set size.
        adjust_method: Multiple testing correction ('fdr', 'bonferroni', 'none').
        species: Species for the analysis ('mouse', 'human').
        database: Gene set database used ('KEGG_Pathways', 'GO_Biological_Process').
        ctx: MCP tool context for logging.
        data_id: Dataset identifier for result tracking.

    Returns:
        EnrichmentResult with enrichment scores and statistics.

    Note:
        LogFC filtering removed. ORA should use genes pre-filtered by
        find_markers. Gene filtering is the responsibility of differential
        expression analysis.
    """
    # Default significance threshold for ORA: standard FDR 5%
    if significance_threshold is None:
        significance_threshold = 0.05

    # Get gene list if not provided
    if gene_list is None:
        # Try to get DEGs from adata
        if "rank_genes_groups" in adata.uns:
            # Get DEGs
            result = adata.uns["rank_genes_groups"]
            names = result["names"]

            # Check if pvals exist (not all rank_genes_groups have pvals)
            pvals = None
            if "pvals_adj" in result:
                pvals = result["pvals_adj"]
            elif "pvals" in result:
                pvals = result["pvals"]

            # Get DEGs from all groups and merge
            # IMPORTANT: names is a numpy recarray with shape (n_genes,)
            # and dtype.names contains group names as fields
            # Access genes by group name: names[group_name][i]
            degs_seen: set[str] = set()
            degs_ordered: list[str] = []

            # Iterate over all groups
            for group_name in names.dtype.names:
                for i in range(len(names)):
                    # Skip genes that don't pass filter criteria
                    if pvals is not None and pvals[group_name][i] >= pvalue_threshold:
                        continue
                    if pvals is None and i >= 100:  # Top 100 genes when no pvals
                        continue
                    gene_name = str(names[group_name][i])
                    if gene_name not in degs_seen:
                        degs_seen.add(gene_name)
                        degs_ordered.append(gene_name)
            gene_list = degs_ordered
        else:
            # Use highly variable genes
            if "highly_variable" in adata.var:
                gene_list = adata.var_names[adata.var["highly_variable"]].tolist()
            else:
                # Use top variable genes (based on Coefficient of Variation)
                # CV = σ/μ is more appropriate than raw variance
                mean = np.array(adata.X.mean(axis=0)).flatten()
                std = _compute_std_sparse_compatible(adata.X, axis=0, ddof=1)

                # Compute CV (avoid division by zero)
                cv = np.zeros_like(mean)
                nonzero_mask = np.abs(mean) > 1e-10
                cv[nonzero_mask] = std[nonzero_mask] / np.abs(mean[nonzero_mask])

                top_indices = top_n_desc_indices(cv, 500, sanitize_nonfinite=True)
                gene_list = adata.var_names[top_indices].tolist()

    # Background genes
    # Use get_raw_data_source (single source of truth) to get complete gene set
    # This handles gene name casing differences between raw and filtered data
    bg_result = get_raw_data_source(adata, prefer_complete_genes=True)
    background_genes = set(bg_result.var_names)

    # Case-insensitive matching as fallback for gene name format differences
    # (e.g., MT.CO1 vs MT-CO1, uppercase vs lowercase)
    query_genes = set(gene_list) & background_genes

    # If no direct matches, try case-insensitive matching
    if len(query_genes) == 0 and len(gene_list) > 0:
        # Create case-insensitive lookup
        gene_name_map = {g.upper(): g for g in background_genes}
        query_genes = set()
        for gene in gene_list:
            if gene.upper() in gene_name_map:
                query_genes.add(gene_name_map[gene.upper()])

    # Perform hypergeometric test for each gene set
    enrichment_scores = {}
    pvalues = {}
    gene_set_statistics = {}

    for gs_name, gs_genes in gene_sets.items():
        gs_genes_set = set(gs_genes) & background_genes
        if len(gs_genes_set) < min_size or len(gs_genes_set) > max_size:
            continue

        # Hypergeometric test
        # a: genes in both query and gene set
        # b: genes in query but not in gene set
        # c: genes in gene set but not in query
        # d: genes in neither
        a = len(query_genes & gs_genes_set)
        b = len(query_genes - gs_genes_set)
        c = len(gs_genes_set - query_genes)
        d = len(background_genes - query_genes - gs_genes_set)

        # Fisher's exact test
        odds_ratio, p_value = stats.fisher_exact(
            [[a, b], [c, d]], alternative="greater"
        )

        enrichment_scores[gs_name] = odds_ratio
        pvalues[gs_name] = p_value
        gene_set_statistics[gs_name] = {
            "odds_ratio": odds_ratio,
            "pval": p_value,
            "overlap": a,
            "query_size": len(query_genes),
            "gs_size": len(gs_genes_set),
            "overlapping_genes": list(query_genes & gs_genes_set)[:20],  # Top 20
        }

    # Multiple testing correction
    _method_map = {"fdr": "fdr_bh", "bonferroni": "bonferroni"}
    if pvalues and adjust_method != "none":
        stats_method = _method_map.get(adjust_method, "fdr_bh")
        pval_array = np.array(list(pvalues.values()))
        _, adjusted_pvals, _, _ = multipletests(pval_array, method=stats_method)
        adjusted_pvalues = dict(zip(pvalues.keys(), adjusted_pvals, strict=False))
    else:
        # No correction: adjusted = raw
        adjusted_pvalues = dict(pvalues) if pvalues else {}

    # Get top results
    sorted_by_pval = sorted(pvalues.items(), key=lambda x: x[1])
    top_gene_sets = [x[0] for x in sorted_by_pval[:10]]

    # Save results to adata.uns for visualization
    # Create DataFrame for visualization compatibility
    ora_df = pd.DataFrame(
        {
            "pathway": list(enrichment_scores),
            "odds_ratio": list(enrichment_scores.values()),
            "pvalue": [pvalues.get(k, 1.0) for k in enrichment_scores],
            "adjusted_pvalue": [
                adjusted_pvalues.get(k, 1.0) for k in enrichment_scores
            ],
        }
    )
    ora_df["NES"] = ora_df["odds_ratio"]  # Use odds_ratio as score for visualization
    ora_df = ora_df.sort_values("pvalue")

    # Shared key for visualization compatibility
    adata.uns["ora_results"] = ora_df

    # Per-run parametrized key so multiple database runs coexist
    analysis_key = _build_enrichment_key("ora", database)
    per_run_key = f"ora_results_{analysis_key.removeprefix('enrichment_')}"
    adata.uns[per_run_key] = ora_df

    # Store gene set membership (shared for compat + per-run for provenance)
    adata.uns["enrichment_ora_gene_sets"] = gene_sets
    per_run_gs_key = (
        f"enrichment_ora_gene_sets_{analysis_key.removeprefix('enrichment_')}"
    )
    adata.uns[per_run_gs_key] = gene_sets

    # Count significant pathways using the pathway significance threshold
    # (distinct from the DEG selection threshold used for input gene filtering)
    n_significant = sum(
        1
        for p in adjusted_pvalues.values()
        if p is not None and p < significance_threshold
    )

    # Store metadata for scientific provenance tracking
    store_analysis_metadata(
        adata,
        analysis_name=analysis_key,
        method="ora",
        parameters={
            "deg_pvalue_threshold": pvalue_threshold,
            "significance_threshold": significance_threshold,
            "adjust_method": adjust_method,
            "min_size": min_size,
            "max_size": max_size,
            "n_query_genes": len(query_genes),
        },
        results_keys={
            "uns": [per_run_key, per_run_gs_key],
        },
        statistics={
            "n_gene_sets": len(gene_sets),
            "n_significant": n_significant,
            "n_query_genes": len(query_genes),
        },
        species=species,
        database=database,
    )

    # Export results to CSV for reproducibility
    if data_id is not None:
        export_analysis_result(adata, data_id, analysis_key)

    # Filter result dicts to significant pathways (reduces MCP response size)
    (
        filtered_statistics,
        filtered_scores,
        filtered_pvals,
        filtered_adj_pvals,
    ) = _filter_significant_statistics(
        gene_set_statistics,
        enrichment_scores,
        pvalues,
        adjusted_pvalues,
        method="ora",
        fdr_threshold=significance_threshold,
    )

    return EnrichmentResult(
        method="ora",
        n_gene_sets=len(gene_sets),
        n_significant=n_significant,
        enrichment_scores=filtered_scores,
        pvalues=filtered_pvals,
        adjusted_pvalues=filtered_adj_pvals,
        gene_set_statistics=filtered_statistics,
        top_gene_sets=top_gene_sets,
        top_depleted_sets=[],  # ORA does not produce depleted gene sets
    )
```
- chatspatial/server.py:660-693 (handler) MCP tool registration of analyze_enrichment. Decorated with @mcp.tool, defines the public API (data_id, params: EnrichmentParameters, context), delegates to tools/enrichment.py's analyze_enrichment, saves result via data_manager.
```python
@mcp.tool(
    annotations=ToolAnnotations(
        readOnlyHint=False,
        idempotentHint=True,
        openWorldHint=True,
    )
)
@mcp_tool_error_handler()
async def analyze_enrichment(
    data_id: str,
    params: EnrichmentParameters,
    context: Optional[Context] = None,
) -> EnrichmentResult:
    """Perform gene set enrichment analysis.

    Args:
        data_id: Dataset ID
        params: Required - species must be specified. See EnrichmentParameters
            for methods and gene_set_database options.
    """
    from .tools.enrichment import analyze_enrichment as analyze_enrichment_func

    # Create ToolContext
    ctx = ToolContext(_data_manager=data_manager, _mcp_context=context)

    # Call enrichment analysis (all business logic is in tools/enrichment.py)
    result = await analyze_enrichment_func(data_id, ctx, params)

    # Save result (keyed by method + database to allow coexistence)
    from .tools.enrichment import _build_enrichment_key

    cache_key = _build_enrichment_key(params.method, params.gene_set_database)
    await data_manager.save_result(data_id, cache_key, result)

    return result
```
- chatspatial/server.py:660-666 (registration) @mcp.tool decorator registration for analyze_enrichment tool with annotations (readOnlyHint=False, idempotentHint=True, openWorldHint=True).
```python
@mcp.tool(
    annotations=ToolAnnotations(
        readOnlyHint=False,
        idempotentHint=True,
        openWorldHint=True,
    )
)
```
- chatspatial/models/data.py:1709-1772 (schema) EnrichmentParameters Pydantic model: defines the input schema for enrichment analysis - species (required), method, gene_sets, gene_set_database, spatial parameters, min/max genes, statistical parameters.
```python
class EnrichmentParameters(BaseModel):
    """Parameters for gene set enrichment analysis"""

    model_config = ConfigDict(extra="forbid")

    # REQUIRED: Species specification (no default value)
    species: Literal["human", "mouse", "zebrafish"]
    # Must explicitly specify the species for gene set matching:
    # - "human": For human data (genes like CD5L, PTPRC - all uppercase)
    # - "mouse": For mouse data (genes like Cd5l, Ptprc - capitalize format)
    # - "zebrafish": For zebrafish data

    # Method selection
    method: Literal[
        "spatial_enrichmap",
        "pathway_gsea",
        "pathway_ora",
        "pathway_enrichr",
        "pathway_ssgsea",
    ] = Field(
        default="spatial_enrichmap",
        description="'spatial_enrichmap' for spatial patterns. 'pathway_gsea'/'pathway_ora' for standard enrichment.",
    )

    # Gene sets
    gene_sets: Optional[Union[list[str], dict[str, list[str]]]] = (
        None  # Gene sets to analyze
    )
    score_keys: Optional[Union[str, list[str]]] = None  # Names for gene signatures

    # Gene set database - choose species-appropriate option
    gene_set_database: Optional[
        Literal[
            "GO_Biological_Process",  # Default (auto-adapts to species)
            "GO_Molecular_Function",  # GO molecular function terms
            "GO_Cellular_Component",  # GO cellular component terms
            "KEGG_Pathways",  # KEGG pathways (species-specific: human=2021, mouse=2019)
            "Reactome_Pathways",  # Reactome pathway database (2022 version)
            "MSigDB_Hallmark",  # MSigDB hallmark gene sets (2020 version)
            "Cell_Type_Markers",  # Cell type marker genes
        ]
    ] = "GO_Biological_Process"

    # Spatial parameters (for spatial_enrichmap)
    spatial_key: str = "spatial"
    n_neighbors: int = Field(
        default=6, gt=0, description="Spatial neighbors for enrichment mapping."
    )
    smoothing: bool = True
    correct_spatial_covariates: bool = True

    # Analysis parameters
    batch_key: Optional[str] = None
    min_genes: int = Field(default=10, gt=0, description="Minimum genes in gene set.")
    max_genes: int = Field(default=500, gt=0, description="Maximum genes in gene set.")

    # Statistical parameters
    pvalue_cutoff: float = Field(
        default=0.05, gt=0.0, lt=1.0, description="P-value significance cutoff."
    )
    adjust_method: Literal["bonferroni", "fdr", "none"] = "fdr"
    n_permutations: int = Field(
        default=1000, gt=0, description="Permutations for GSEA."
    )
```

- `EnrichmentResult` Pydantic model: output schema with `method`, `n_gene_sets`, `n_significant`, `top_gene_sets`, `top_depleted_sets`. Large dicts (`enrichment_scores`, `pvalues`, etc.) are excluded from the MCP response.
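For instance, two minimal `params` payloads consistent with the `EnrichmentParameters` schema above might look like this (gene names are illustrative, not taken from a real dataset):

```python
# Example payloads consistent with the EnrichmentParameters schema;
# the gene signature below is illustrative only.
ora_params = {
    "species": "mouse",                # required, no default
    "method": "pathway_ora",
    "gene_set_database": "KEGG_Pathways",
    "pvalue_cutoff": 0.05,
    "adjust_method": "fdr",
}

spatial_params = {
    "species": "human",
    "method": "spatial_enrichmap",     # the default method
    "gene_sets": {"my_signature": ["CD5L", "PTPRC"]},  # custom set, human casing
    "n_neighbors": 6,
    "smoothing": True,
}
```

Note that `species` controls gene symbol matching, so a human payload should use all-uppercase symbols while a mouse payload should use capitalized ones (Cd5l, Ptprc).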
```python
class EnrichmentResult(BaseAnalysisResult):
    """Result from gene set enrichment analysis

    Note on serialization:
    To minimize MCP response size (~12k tokens -> ~0.5k tokens), large
    dictionaries are excluded from JSON serialization using Field(exclude=True).
    These fields are still stored in the Python object and saved to adata.uns
    for downstream visualization.

    Fields included in MCP response (sent to LLM):
    - method, n_gene_sets, n_significant (basic info)
    - top_gene_sets, top_depleted_sets (top 10 pathway names)
    - spatial_scores_key (for spatial methods)

    Fields excluded from MCP response (stored in adata.uns):
    - enrichment_scores, pvalues, adjusted_pvalues (full dicts)
    - gene_set_statistics (detailed stats per pathway)
    - spatial_metrics (spatial autocorrelation data)
    """

    # Basic information - always included in MCP response
    method: str  # Method used (pathway_gsea, pathway_ora, etc.)
    n_gene_sets: int  # Number of gene sets analyzed
    n_significant: int  # Number of statistically significant gene sets (0 when no test)

    # Top results - always included (compact, just pathway names)
    top_gene_sets: list[str]  # Top enriched gene sets (max 10)
    top_depleted_sets: list[str]  # Top depleted gene sets (max 10)

    # Spatial enrichment: number of signatures successfully computed
    # (distinct from n_significant which requires statistical testing)
    n_successful_signatures: Optional[int] = None

    # Spatial info key - included
    spatial_scores_key: Optional[str] = None  # Key in adata.obsm

    # ============================================================
    # EXCLUDED FROM MCP RESPONSE - stored in adata.uns for viz
    # Full data available via visualize_data() tool
    # ============================================================
    enrichment_scores: dict[str, float] = Field(
        default_factory=dict,
        exclude=True,  # Exclude from JSON serialization to LLM
    )
    pvalues: Optional[dict[str, float]] = Field(
        default=None,
        exclude=True,
    )
    adjusted_pvalues: Optional[dict[str, float]] = Field(
        default=None,
        exclude=True,
    )
    gene_set_statistics: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        exclude=True,
    )
    spatial_metrics: Optional[dict[str, Any]] = Field(
        default=None,
        exclude=True,
    )
```
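The effect of `Field(exclude=True)` can be approximated without Pydantic; the sketch below shows the serialization behavior only (the names `EnrichmentResultSketch` and `to_mcp_dict` are illustrative, and the real model relies on Pydantic's built-in exclusion):

```python
# Sketch of the "exclude large fields from the MCP response" idea using only
# the standard library; the real model uses Pydantic's Field(exclude=True).
from dataclasses import dataclass, field, fields

EXCLUDED = {"enrichment_scores", "pvalues"}  # illustrative subset of excluded fields

@dataclass
class EnrichmentResultSketch:
    method: str
    n_gene_sets: int
    top_gene_sets: list = field(default_factory=list)
    enrichment_scores: dict = field(default_factory=dict)  # large, kept off the wire
    pvalues: dict = field(default_factory=dict)            # large, kept off the wire

    def to_mcp_dict(self) -> dict:
        """Serialize only the compact fields that should reach the LLM."""
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name not in EXCLUDED
        }

result = EnrichmentResultSketch(
    method="pathway_ora",
    n_gene_sets=250,
    top_gene_sets=["Apoptosis"],
    enrichment_scores={"Apoptosis": 2.1},
)
compact = result.to_mcp_dict()
```

The full object still holds the excluded dicts in memory; only the serialized view sent over MCP is trimmed, mirroring how the real result keeps full data in `adata.uns` for later visualization.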