"""
RNA velocity analysis for spatial transcriptomics.
This module computes RNA velocity to infer the direction of cellular state changes
by analyzing the balance of spliced and unspliced mRNA counts.
Key functionality:
- `analyze_rna_velocity`: Main MCP entry point for velocity analysis
- Supports scVelo (standard) and VELOVI (deep learning) methods
"""
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
if TYPE_CHECKING:
import anndata as ad
from ..spatial_mcp_adapter import ToolContext
from ..models.analysis import RNAVelocityResult
from ..models.data import RNAVelocityParameters
from ..utils.adata_utils import (
store_analysis_metadata,
store_velovi_essential_data,
validate_adata,
)
from ..utils.dependency_manager import require
from ..utils.exceptions import (
DataError,
DataNotFoundError,
ParameterError,
ProcessingError,
)
from ..utils.mcp_utils import suppress_output
from ..utils.results_export import export_analysis_result
def preprocess_for_velocity(
adata: "ad.AnnData",
min_shared_counts: int = 30,
n_top_genes: int = 2000,
n_pcs: int = 30,
n_neighbors: int = 30,
params: Optional[RNAVelocityParameters] = None,
) -> "ad.AnnData":
"""Prepare AnnData for RNA velocity analysis using scVelo pipeline.
Performs the standard scVelo preprocessing workflow:
1. Filtering genes based on minimum shared counts between spliced/unspliced layers
2. Normalizing the data
3. Selecting highly variable genes
4. Computing first and second-order moments across nearest neighbors
Args:
adata: Annotated data matrix with 'spliced' and 'unspliced' layers.
min_shared_counts: Minimum counts shared between spliced and unspliced layers.
n_top_genes: Number of highly variable genes to use.
n_pcs: Number of principal components to compute.
n_neighbors: Number of nearest neighbors for moment computation.
params: If provided, overrides the individual parameters.
Returns:
Preprocessed AnnData with velocity-ready layers and moments.
"""
import scvelo as scv
# If params object is provided, use its values
if params is not None:
from ..models.data import RNAVelocityParameters
if isinstance(params, RNAVelocityParameters):
min_shared_counts = params.min_shared_counts
n_top_genes = params.n_top_genes
n_pcs = params.n_pcs
n_neighbors = params.n_neighbors
# Validate velocity data
try:
validate_adata(adata, {}, check_velocity=True)
except DataNotFoundError as e:
raise DataError(f"Invalid velocity data: {e}") from e
# Standard preprocessing with configurable parameters
# enforce=True ensures scvelo recomputes everything even if data was pre-normalized
# This is important when running after MCP's general preprocessing step
scv.pp.filter_and_normalize(
adata,
min_shared_counts=min_shared_counts,
n_top_genes=n_top_genes,
enforce=True,
)
scv.pp.moments(adata, n_pcs=n_pcs, n_neighbors=n_neighbors)
return adata
def compute_rna_velocity(
adata: "ad.AnnData",
mode: str = "stochastic",
params: Optional[RNAVelocityParameters] = None,
) -> "ad.AnnData":
"""Compute RNA velocity to infer cellular differentiation direction.
Executes the core RNA velocity workflow:
1. Ensures preprocessing (moment computation) is complete
2. Estimates RNA velocity using the specified model
3. Constructs a velocity graph for cell-to-cell transitions
Args:
adata: Annotated data matrix with 'spliced' and 'unspliced' layers.
mode: Model for velocity estimation:
- 'stochastic': Likelihood-based model accounting for noise
- 'deterministic': Simpler steady-state model
- 'dynamical': Full transcriptional dynamics with ODE fitting
params: Parameter object (mode extracted from params.scvelo_mode).
Returns:
AnnData updated with velocity vectors and graph.
"""
import scvelo as scv
# Use params for mode if provided
if params is not None:
from ..models.data import RNAVelocityParameters
if isinstance(params, RNAVelocityParameters):
mode = params.scvelo_mode
# Check if preprocessing is needed
if "Ms" not in adata.layers or "Mu" not in adata.layers:
adata = preprocess_for_velocity(adata, params=params)
# Compute velocity based on mode
if mode == "dynamical":
scv.tl.recover_dynamics(adata)
scv.tl.velocity(adata, mode="dynamical")
# Compute latent time (required for gene_trends visualization)
scv.tl.latent_time(adata)
else:
scv.tl.velocity(adata, mode=mode)
# Compute velocity graph
scv.tl.velocity_graph(adata)
return adata
async def _prepare_velovi_data(
adata: "ad.AnnData",
ctx: Optional["ToolContext"],
) -> "ad.AnnData":
"""Prepare data for VELOVI according to official standards.
Args:
adata: Input AnnData with spliced/unspliced layers.
ctx: ToolContext for logging warnings.
Returns:
Preprocessed AnnData copy ready for VELOVI.
"""
import scvelo as scv
adata_velovi = adata.copy()
# Convert layer names to VELOVI standards
if "spliced" in adata_velovi.layers and "unspliced" in adata_velovi.layers:
adata_velovi.layers["Ms"] = adata_velovi.layers["spliced"]
adata_velovi.layers["Mu"] = adata_velovi.layers["unspliced"]
else:
raise DataNotFoundError("Missing required 'spliced' and 'unspliced' layers")
# scvelo preprocessing
# enforce=True ensures scvelo recomputes everything even if data was pre-normalized
try:
scv.pp.filter_and_normalize(
adata_velovi, min_shared_counts=30, n_top_genes=2000, enforce=True
)
except Exception as e:
if ctx:
await ctx.warning(f"scvelo preprocessing warning: {e}")
# Compute moments
try:
scv.pp.moments(adata_velovi, n_pcs=30, n_neighbors=30)
except Exception as e:
if ctx:
await ctx.warning(f"moments computation warning: {e}")
return adata_velovi
def _validate_velovi_data(adata: "ad.AnnData") -> bool:
"""Validate data for VELOVI requirements.
Args:
adata: AnnData to validate for VELOVI compatibility.
Returns:
True if validation passes.
Raises:
DataNotFoundError: If required layers are missing.
DataError: If layer shapes are incompatible.
"""
if "Ms" not in adata.layers or "Mu" not in adata.layers:
raise DataNotFoundError("Missing required layers 'Ms' and 'Mu' for VELOVI")
ms_data = adata.layers["Ms"]
mu_data = adata.layers["Mu"]
if ms_data.shape != mu_data.shape:
raise DataError(f"Shape mismatch: Ms {ms_data.shape} vs Mu {mu_data.shape}")
if ms_data.ndim != 2 or mu_data.ndim != 2:
raise DataError(
f"Expected 2D arrays, got Ms:{ms_data.ndim}D, Mu:{mu_data.ndim}D"
)
return True
async def analyze_velocity_with_velovi(
adata: "ad.AnnData",
n_epochs: int = 1000,
n_hidden: int = 128,
n_latent: int = 10,
use_gpu: bool = False,
ctx: Optional["ToolContext"] = None,
) -> dict[str, Any]:
"""
Analyzes RNA velocity using the deep learning model VELOVI.
VELOVI (Velocity Variational Inference) is a probabilistic deep generative model
that estimates transcriptional dynamics from spliced and unspliced mRNA counts.
It provides velocity vectors with uncertainty quantification.
Args:
adata: AnnData with 'spliced' and 'unspliced' layers.
n_epochs: Number of training epochs.
n_hidden: Number of hidden units in neural network layers.
n_latent: Dimensionality of the latent space.
use_gpu: If True, use GPU for training.
ctx: ToolContext for logging.
Returns:
Dictionary with VELOVI results and metadata.
"""
try:
require("scvi", feature="VELOVI velocity analysis")
from scvi.external import VELOVI
# Data preprocessing
adata_prepared = await _prepare_velovi_data(adata, ctx)
# Data validation
_validate_velovi_data(adata_prepared)
# VELOVI setup
VELOVI.setup_anndata(
adata_prepared,
spliced_layer="Ms",
unspliced_layer="Mu",
)
# Model creation
velovi_model = VELOVI(adata_prepared, n_hidden=n_hidden, n_latent=n_latent)
# Model training
if use_gpu:
velovi_model.train(max_epochs=n_epochs, accelerator="gpu")
else:
velovi_model.train(max_epochs=n_epochs)
# Result extraction
latent_time = velovi_model.get_latent_time(n_samples=25)
velocities = velovi_model.get_velocity(n_samples=25, velo_statistic="mean")
latent_repr = velovi_model.get_latent_representation()
# Handle pandas/numpy compatibility
if hasattr(latent_time, "values"):
latent_time = latent_time.values
if hasattr(velocities, "values"):
velocities = velocities.values
# Ensure numpy array format
latent_time = np.asarray(latent_time)
velocities = np.asarray(velocities)
latent_repr = np.asarray(latent_repr)
# Safe scaling calculation
t = latent_time
if t.ndim > 1:
t_max = np.max(t, axis=0)
if np.all(t_max > 0):
scaling = 20 / t_max
else:
scaling = np.where(t_max > 0, 20 / t_max, 1.0)
else:
t_max = np.max(t)
scaling = 20 / t_max if t_max > 0 else 1.0
if hasattr(scaling, "to_numpy"):
scaling = scaling.to_numpy()
scaling = np.asarray(scaling)
# Calculate scaled velocities
if scaling.ndim == 0:
scaled_velocities = velocities / scaling
elif scaling.ndim == 1 and velocities.ndim == 2:
scaled_velocities = velocities / scaling[np.newaxis, :]
else:
scaled_velocities = velocities / scaling
# Store results in preprocessed data object
adata_prepared.layers["velocity_velovi"] = scaled_velocities
adata_prepared.layers["latent_time_velovi"] = latent_time
adata_prepared.obsm["X_velovi_latent"] = latent_repr
# Calculate velocity statistics
velocity_norm = np.linalg.norm(scaled_velocities, axis=1)
adata_prepared.obs["velocity_velovi_norm"] = velocity_norm
# Transfer key information back to original adata
adata.obs["velocity_velovi_norm"] = velocity_norm
adata.obsm["X_velovi_latent"] = latent_repr
# Store essential data for CellRank (optimized: ~78% memory savings)
# Instead of storing full adata (~160 MB), stores only velocity/neighbors (~35 MB)
store_velovi_essential_data(adata, adata_prepared)
return {
"method": "VELOVI",
"velocity_computed": True,
"n_epochs": n_epochs,
"n_hidden": n_hidden,
"n_latent": n_latent,
"velocity_shape": scaled_velocities.shape,
"latent_time_shape": latent_time.shape,
"latent_repr_shape": latent_repr.shape,
"velocity_mean_norm": float(velocity_norm.mean()),
"velocity_std_norm": float(velocity_norm.std()),
"n_genes_analyzed": adata_prepared.n_vars,
"original_n_genes": adata.n_vars,
"training_completed": True,
"device": "GPU" if use_gpu else "CPU",
}
except Exception as e:
raise ProcessingError(f"VELOVI velocity analysis failed: {e}") from e
async def analyze_rna_velocity(
data_id: str,
ctx: "ToolContext",
params: RNAVelocityParameters = RNAVelocityParameters(),
) -> RNAVelocityResult:
"""
Computes RNA velocity for spatial transcriptomics data.
This is the main MCP entry point for velocity analysis. It requires
'spliced' and 'unspliced' count layers in the input dataset.
Args:
data_id: Dataset identifier.
ctx: ToolContext for data access and logging.
params: RNA velocity parameters.
Returns:
RNAVelocityResult with computation metadata.
Raises:
DataNotFoundError: If data lacks required layers.
ProcessingError: If velocity computation fails.
"""
require("scvelo")
import scvelo as scv # noqa: F401
# Get AnnData object
adata = await ctx.get_adata(data_id)
# Validate data for velocity analysis
try:
validate_adata(adata, {}, check_velocity=True)
except DataNotFoundError as e:
raise DataNotFoundError(
f"Missing velocity data: {e}. Requires 'spliced' and 'unspliced' layers."
) from e
velocity_computed = False
velocity_method_used = params.method
# Dispatch based on method
if params.method == "scvelo":
with suppress_output():
try:
adata = compute_rna_velocity(
adata, mode=params.scvelo_mode, params=params
)
velocity_computed = True
except Exception as e:
raise ProcessingError(
f"scVelo RNA velocity analysis failed: {e}"
) from e
elif params.method == "velovi":
require("scvi", feature="VELOVI velocity analysis")
try:
velovi_results = await analyze_velocity_with_velovi(
adata,
n_epochs=params.velovi_n_epochs,
n_hidden=params.velovi_n_hidden,
n_latent=params.velovi_n_latent,
use_gpu=params.velovi_use_gpu,
ctx=ctx,
)
if velovi_results.get("velocity_computed", False):
velocity_computed = True
adata.uns["velocity_graph"] = True
adata.uns["velocity_method"] = "velovi"
else:
raise ProcessingError("VELOVI failed to compute velocity")
except Exception as e:
raise ProcessingError(f"VELOVI velocity analysis failed: {e}") from e
else:
raise ParameterError(f"Unknown velocity method: {params.method}")
# Build results keys based on what was computed
# Note: velocity layers NOT exported (too large for CSV)
method_used = velocity_method_used if params.method == "scvelo" else params.method
results_keys: dict[str, list[str]] = {
"uns": ["velocity_method"],
"obs": [],
"obsm": [],
}
# VELOVI results
if "velocity_velovi_norm" in adata.obs:
results_keys["obs"].append("velocity_velovi_norm")
if "X_velovi_latent" in adata.obsm:
results_keys["obsm"].append("X_velovi_latent")
# scvelo dynamical mode results
if "latent_time" in adata.obs:
results_keys["obs"].append("latent_time")
# Store metadata for scientific provenance tracking
store_analysis_metadata(
adata,
analysis_name=f"velocity_{method_used}",
method=method_used,
parameters={
"n_top_genes": params.n_top_genes,
"n_pcs": params.n_pcs,
"n_neighbors": params.n_neighbors,
"min_shared_counts": params.min_shared_counts,
},
results_keys=results_keys,
statistics={
"velocity_computed": velocity_computed,
},
)
# Export results for reproducibility
export_analysis_result(adata, data_id, f"velocity_{method_used}")
return RNAVelocityResult(
data_id=data_id,
velocity_computed=velocity_computed,
velocity_graph_key="velocity_graph" if velocity_computed else None,
mode=velocity_method_used if params.method == "scvelo" else params.method,
)