"""
RCTD (Robust Cell Type Decomposition) deconvolution method.
RCTD is an R-based deconvolution method that performs robust
decomposition of cell type mixtures via the spacexr package.
"""
import warnings
from typing import TYPE_CHECKING, Any
import numpy as np
import pandas as pd
if TYPE_CHECKING:
pass
from ...utils.dependency_manager import validate_r_package
from ...utils.exceptions import DataError, ParameterError, ProcessingError
from .base import PreparedDeconvolutionData, create_deconvolution_stats
def deconvolve(
data: PreparedDeconvolutionData,
mode: str = "full",
max_cores: int = 4,
confidence_threshold: float = 10.0,
doublet_threshold: float = 25.0,
max_multi_types: int = 4,
) -> tuple[pd.DataFrame, dict[str, Any]]:
"""Deconvolve spatial data using RCTD from spacexr R package.
Args:
data: Prepared deconvolution data (immutable, includes spatial coordinates)
mode: RCTD mode - 'full', 'doublet', or 'multi'
max_cores: Maximum CPU cores
confidence_threshold: Confidence threshold
doublet_threshold: Doublet detection threshold
max_multi_types: Max cell types per spot in multi mode
Returns:
Tuple of (proportions DataFrame, statistics dictionary)
"""
import anndata2ri
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter
ctx = data.ctx
# Validate mode-specific parameters
if mode == "multi" and max_multi_types >= data.n_cell_types:
raise ParameterError(
f"MAX_MULTI_TYPES ({max_multi_types}) must be less than "
f"total cell types ({data.n_cell_types})."
)
# Validate R package
validate_r_package(
"spacexr",
ctx,
install_cmd="devtools::install_github('dmcable/spacexr', build_vignettes = FALSE)",
)
try:
# Load R packages using ro.r() instead of importr() to avoid
# conversion context issues in async environments
with localconverter(ro.default_converter + pandas2ri.converter):
ro.r("library(spacexr)")
# Data already copied in prepare_deconvolution
spatial_data = data.spatial
reference_data = data.reference
# Get spatial coordinates from prepared data
if data.spatial_coords is not None:
coords = pd.DataFrame(
data.spatial_coords[:, :2],
index=spatial_data.obs_names,
columns=["x", "y"],
)
else:
coords = pd.DataFrame(
{"x": range(spatial_data.n_obs), "y": [0] * spatial_data.n_obs},
index=spatial_data.obs_names,
)
# Prepare cell type information
cell_types = reference_data.obs[data.cell_type_key].copy()
cell_types = cell_types.str.replace("/", "_", regex=False)
cell_types = cell_types.str.replace(" ", "_", regex=False)
# RCTD requires minimum 25 cells per cell type
MIN_CELLS_PER_TYPE = 25
cell_type_counts = cell_types.value_counts()
rare_types = cell_type_counts[
cell_type_counts < MIN_CELLS_PER_TYPE
].index.tolist()
if rare_types:
warnings.warn(
f"RCTD requires ≥{MIN_CELLS_PER_TYPE} cells per cell type. "
f"Filtering {len(rare_types)} rare types: {rare_types}",
UserWarning,
stacklevel=2,
)
keep_mask = ~cell_types.isin(rare_types)
reference_data = reference_data[keep_mask].copy()
cell_types = cell_types[keep_mask]
remaining_types = cell_types.unique()
if len(remaining_types) < 2:
raise DataError(
f"After filtering rare cell types, only {len(remaining_types)} "
f"cell type(s) remain. RCTD requires at least 2 cell types."
)
cell_types_series = pd.Series(
cell_types.values, index=reference_data.obs_names, name="cell_type"
)
# Calculate nUMI
spatial_numi = pd.Series(
np.asarray(spatial_data.X.sum(axis=1)).ravel(),
index=spatial_data.obs_names,
name="nUMI",
)
reference_numi = pd.Series(
np.asarray(reference_data.X.sum(axis=1)).ravel(),
index=reference_data.obs_names,
name="nUMI",
)
# Transfer matrices to R
with localconverter(ro.default_converter + anndata2ri.converter):
ro.globalenv["spatial_counts"] = spatial_data.X.T
ro.globalenv["reference_counts"] = reference_data.X.T
ro.globalenv["gene_names_spatial"] = ro.StrVector(spatial_data.var_names)
ro.globalenv["spot_names"] = ro.StrVector(spatial_data.obs_names)
ro.globalenv["gene_names_ref"] = ro.StrVector(reference_data.var_names)
ro.globalenv["cell_names"] = ro.StrVector(reference_data.obs_names)
ro.r(
"""
rownames(spatial_counts) <- gene_names_spatial
colnames(spatial_counts) <- spot_names
rownames(reference_counts) <- gene_names_ref
colnames(reference_counts) <- cell_names
"""
)
# Transfer other data
with localconverter(ro.default_converter + pandas2ri.converter):
ro.globalenv["coords"] = ro.conversion.py2rpy(coords)
ro.globalenv["numi_spatial"] = ro.conversion.py2rpy(spatial_numi)
ro.globalenv["cell_types_vec"] = ro.conversion.py2rpy(cell_types_series)
ro.globalenv["numi_ref"] = ro.conversion.py2rpy(reference_numi)
ro.globalenv["max_cores_val"] = max_cores
ro.globalenv["rctd_mode"] = mode
ro.globalenv["conf_thresh"] = confidence_threshold
ro.globalenv["doub_thresh"] = doublet_threshold
ro.globalenv["max_multi_types_val"] = max_multi_types
# Run RCTD in R
ro.r(
"""
puck <- SpatialRNA(coords, spatial_counts, numi_spatial)
cell_types_factor <- as.factor(cell_types_vec)
names(cell_types_factor) <- names(cell_types_vec)
reference <- Reference(reference_counts, cell_types_factor, numi_ref, min_UMI = 5)
myRCTD <- create.RCTD(puck, reference, max_cores = max_cores_val,
MAX_MULTI_TYPES = max_multi_types_val, UMI_min_sigma = 10)
myRCTD@config$CONFIDENCE_THRESHOLD <- conf_thresh
myRCTD@config$DOUBLET_THRESHOLD <- doub_thresh
myRCTD <- run.RCTD(myRCTD, doublet_mode = rctd_mode)
"""
)
# Extract results
proportions = _extract_rctd_results(mode)
# Validate results
if proportions.isna().any().any():
nan_count = proportions.isna().sum().sum()
warnings.warn(
f"RCTD produced {nan_count} NaN values", UserWarning, stacklevel=2
)
if (proportions < 0).any().any():
neg_count = (proportions < 0).sum().sum()
raise ProcessingError(f"RCTD error: {neg_count} negative values")
# Create statistics
stats = create_deconvolution_stats(
proportions,
data.common_genes,
method=f"RCTD-{mode}",
device="CPU",
mode=mode,
max_cores=max_cores,
confidence_threshold=confidence_threshold,
doublet_threshold=doublet_threshold,
)
# Clean up R global environment
ro.r(
"""
rm(list = c("spatial_counts", "reference_counts", "gene_names_spatial",
"spot_names", "gene_names_ref", "cell_names", "coords",
"numi_spatial", "cell_types_vec", "numi_ref", "max_cores_val",
"rctd_mode", "conf_thresh", "doub_thresh", "max_multi_types_val",
"puck", "cell_types_factor", "reference", "myRCTD",
"weights_matrix", "cell_type_names"),
envir = .GlobalEnv)
gc()
"""
)
return proportions, stats
except Exception as e:
if isinstance(e, (ParameterError, ProcessingError)):
raise
raise ProcessingError(f"RCTD deconvolution failed: {e}") from e
def _extract_rctd_results(mode: str) -> pd.DataFrame:
"""Extract RCTD results from R environment."""
import rpy2.robjects as ro
from rpy2.robjects import numpy2ri, pandas2ri
from rpy2.robjects.conversion import localconverter
with localconverter(
ro.default_converter + pandas2ri.converter + numpy2ri.converter
):
if mode == "full":
ro.r(
"""
weights_matrix <- myRCTD@results$weights
cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
spot_names <- rownames(weights_matrix)
"""
)
elif mode == "doublet":
ro.r(
"""
if("weights_doublet" %in% names(myRCTD@results) && "results_df" %in% names(myRCTD@results)) {
weights_doublet <- myRCTD@results$weights_doublet
results_df <- myRCTD@results$results_df
cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
spot_names <- rownames(results_df)
n_spots <- length(spot_names)
n_cell_types <- length(cell_type_names)
weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
rownames(weights_matrix) <- spot_names
colnames(weights_matrix) <- cell_type_names
for(i in 1:n_spots) {
spot_class <- results_df$spot_class[i]
if(spot_class %in% c("doublet_certain", "doublet_uncertain")) {
first_type <- as.character(results_df$first_type[i])
second_type <- as.character(results_df$second_type[i])
if(first_type %in% cell_type_names) {
first_idx <- which(cell_type_names == first_type)
weights_matrix[i, first_idx] <- weights_doublet[i, "first_type"]
}
if(second_type %in% cell_type_names && second_type != first_type) {
second_idx <- which(cell_type_names == second_type)
weights_matrix[i, second_idx] <- weights_doublet[i, "second_type"]
}
} else if(spot_class == "singlet") {
first_type <- as.character(results_df$first_type[i])
if(first_type %in% cell_type_names) {
first_idx <- which(cell_type_names == first_type)
weights_matrix[i, first_idx] <- 1.0
}
}
}
} else {
stop("Official doublet mode structures not found")
}
"""
)
else: # multi mode
ro.r(
"""
results_list <- myRCTD@results
spot_names <- colnames(myRCTD@spatialRNA@counts)
cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
n_spots <- length(spot_names)
n_cell_types <- length(cell_type_names)
weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
rownames(weights_matrix) <- spot_names
colnames(weights_matrix) <- cell_type_names
for(i in 1:n_spots) {
spot_result <- results_list[[i]]
predicted_types <- spot_result$cell_type_list
proportions <- spot_result$sub_weights
for(j in seq_along(predicted_types)) {
cell_type <- predicted_types[j]
if(cell_type %in% cell_type_names) {
col_idx <- which(cell_type_names == cell_type)
weights_matrix[i, col_idx] <- proportions[j]
}
}
}
"""
)
weights_r = ro.r("as.matrix(weights_matrix)")
cell_type_names_r = ro.r("cell_type_names")
spot_names_r = ro.r("spot_names")
weights_array = ro.conversion.rpy2py(weights_r)
cell_type_names = ro.conversion.rpy2py(cell_type_names_r)
spot_names = ro.conversion.rpy2py(spot_names_r)
return pd.DataFrame(weights_array, index=spot_names, columns=cell_type_names)