Skip to main content
Glama
rna_inverse_design.py14.1 kB
#!/usr/bin/env python3
"""
Script: rna_inverse_design.py
Description: RNA Inverse Design using gRNAde - Generate RNA sequences that fold into specified structures

Original Use Case: examples/use_case_1_rna_inverse_design_fixed.py
Dependencies Removed: Inlined constants, simplified repo imports

Usage:
    python scripts/rna_inverse_design.py --secondary_structure "((((....))))" --mode 2d --n_pass 5 --output_dir results/rna_design

Example:
    python scripts/rna_inverse_design.py --secondary_structure "((((....))))" --mode 2d --output_dir results/test
"""

# ==============================================================================
# Minimal Imports (only essential packages)
# ==============================================================================
import argparse
import json
import os
import random
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union

# Essential scientific packages
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# ==============================================================================
# Configuration (extracted from use case)
# ==============================================================================
DEFAULT_CONFIG = {
    "model_checkpoint": "gRNAde_drop3d@0.75_maxlen@500.h5",
    "mode": "2d",  # "2d" or "3d"
    "total_samples": 1000,
    "n_samples": 32,
    "n_pass": 100,
    "temperature_min": 0.1,
    "temperature_max": 1.0,
    "pass_threshold": 80,
    "seed": 42,
    "device": "cpu",  # or "cuda" if available
}

# ==============================================================================
# Inlined Constants (from repo/geometric-rna-design/src/constants.py)
# ==============================================================================
# RNA nucleotides mapping
LETTER_TO_NUM = {"A": 0, "G": 1, "C": 2, "U": 3}
NUM_TO_LETTER = {0: "A", 1: "G", 2: "C", 3: "U"}
FILL_VALUE = 1e-5


# ==============================================================================
# Inlined Utility Functions (simplified from repo)
# ==============================================================================
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        seed: Value passed to every generator. CUDA generators are seeded
            only when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def create_partial_seq_logit_bias(partial_seq, device, model_out_dim=4):
    """Build an additive logit bias that pins the constrained positions.

    Args:
        partial_seq: Per-position constraint string, or None. Characters found
            in LETTER_TO_NUM (A, G, C, U) pin that position; any other
            character leaves the position unconstrained.
        device: Torch device (or device string) the bias is created on.
        model_out_dim: Size of the class dimension of the bias tensor.

    Returns:
        None when ``partial_seq`` is None, otherwise a tensor of shape
        ``(len(partial_seq), model_out_dim)`` holding 0 for allowed classes
        and -1e10 for classes excluded by a constraint.
    """
    if partial_seq is None:
        return None

    bias = torch.zeros(len(partial_seq), model_out_dim, device=device)
    for position, nucleotide in enumerate(partial_seq):
        target = LETTER_TO_NUM.get(nucleotide)
        if target is None:
            continue
        # Suppress every class, then re-enable only the requested nucleotide.
        bias[position, :] = -1e10
        bias[position, target] = 0
    return bias


def get_repo_imports():
    """Import the gRNAde repo modules on demand and return them by name.

    The repo is expected at ``<script dir>/../repo/geometric-rna-design``;
    that directory is put at the front of ``sys.path`` before importing.

    Returns:
        Dict mapping each imported symbol's name to the symbol itself.

    Raises:
        ImportError: If any of the repo modules cannot be imported.
    """
    repo_path = Path(__file__).parent.parent / "repo" / "geometric-rna-design"
    repo_entry = str(repo_path)
    # Avoid piling up duplicate entries when this is called more than once.
    if repo_entry not in sys.path:
        sys.path.insert(0, repo_entry)

    try:
        from src.data.featurizer import RNAGraphFeaturizer
        from src.models import gRNAde
        from src.evaluator import (
            openknot_score_ribonanzanet,
            self_consistency_score_ribonanzanet,
            self_consistency_score_ribonanzanet_sec_struct,
        )
        from tools.ribonanzanet.network import RibonanzaNet
        from tools.ribonanzanet_sec_struct.network import RibonanzaNetSS
    except ImportError as e:
        # Chain the original error so the failing module stays visible.
        raise ImportError(f"Could not import gRNAde modules: {e}. "
                          f"Make sure repo is at: {repo_path}") from e

    return {
        "RNAGraphFeaturizer": RNAGraphFeaturizer,
        "gRNAde": gRNAde,
        "openknot_score_ribonanzanet": openknot_score_ribonanzanet,
        "self_consistency_score_ribonanzanet": self_consistency_score_ribonanzanet,
        "self_consistency_score_ribonanzanet_sec_struct": self_consistency_score_ribonanzanet_sec_struct,
        "RibonanzaNet": RibonanzaNet,
        "RibonanzaNetSS": RibonanzaNetSS,
    }


def load_model_checkpoint(config: Dict[str, Any]):
    """Locate the checkpoint named in ``config`` and build the gRNAde model.

    Args:
        config: Must contain "model_checkpoint" and "mode"; "device" is
            optional and defaults to "cpu".

    Returns:
        Tuple ``(model, imports)`` where ``imports`` is the dict returned by
        ``get_repo_imports``.

    Raises:
        ImportError: If the repo modules cannot be imported.
        FileNotFoundError: If the checkpoint exists in none of the candidate
            locations.
    """
    imports = get_repo_imports()

    script_dir = Path(__file__).parent
    mcp_root = script_dir.parent

    # Candidate locations, checked in order; the first existing one wins.
    possible_paths = [
        mcp_root / "models" / config["model_checkpoint"],
        mcp_root / "repo" / "geometric-rna-design" / "models" / config["model_checkpoint"],
        Path(config["model_checkpoint"]),  # If absolute path provided
    ]
    model_path = next((path for path in possible_paths if path.exists()), None)
    if model_path is None:
        raise FileNotFoundError(
            f"Model checkpoint not found. Tried: {[str(p) for p in possible_paths]}"
        )

    device = torch.device(config.get("device", "cpu"))
    model = imports["gRNAde"](
        model_path=str(model_path),
        device=device,
        num_layers=20,  # Default from repo
        model_type=config["mode"],
    )
    return model, imports


def _broadcast_logit_bias(logit_bias, logits):
    """Reshape a (length, classes) bias so it can be added to ``logits``."""
    if logit_bias.size(0) != logits.size(0) or logit_bias.size(-1) != logits.size(-1):
        raise ValueError(
            f"Partial sequence bias of shape {tuple(logit_bias.shape)} does not "
            f"match model logits of shape {tuple(logits.shape)}"
        )
    # Singleton axes for every dimension between positions and classes.
    middle_axes = [1] * (logits.dim() - 2)
    bias = logit_bias.to(device=logits.device, dtype=logits.dtype)
    return bias.view(logit_bias.size(0), *middle_axes, logit_bias.size(-1))


def _sample_with_temperatures(model, featurizer, structure_data, config, logit_bias=None):
    """Sample sequences and record the perplexity and temperature of each one.

    Returns:
        Tuple ``(sequences, perplexities, temperatures)`` of equal-length
        lists, each truncated to ``config["total_samples"]`` entries.
    """
    device = torch.device(config.get("device", "cpu"))
    data = featurizer(structure_data).to(device)

    n_samples = int(config["n_samples"])
    total_samples = config["total_samples"]

    sequences = []
    perplexities = []
    temperatures = []

    with torch.no_grad():
        for _ in range(config["n_pass"]):
            # Stop as soon as enough sequences exist; extra passes were discarded anyway.
            if n_samples <= 0 or len(sequences) >= total_samples:
                break

            temperature = float(
                np.random.uniform(config["temperature_min"], config["temperature_max"])
            )
            logits = model(data) / temperature
            if logits.dim() < 2:
                raise ValueError(
                    f"Expected model logits with at least 2 dimensions, got {tuple(logits.shape)}"
                )
            if logit_bias is not None:
                logits = logits + _broadcast_logit_bias(logit_bias, logits)

            probs = F.softmax(logits, dim=-1)
            log_probs = F.log_softmax(logits, dim=-1)
            num_classes = probs.size(-1)

            # NOTE(review): axis 0 is treated as the sequence position and any
            # further leading axis as a sample axis, matching the indexing the
            # script has always used - confirm against the model's output layout.
            if probs.dim() == 2:
                # One distribution per position: draw n_samples independent sequences.
                draws = torch.multinomial(probs, n_samples, replacement=True)
                log_probs = log_probs.unsqueeze(1).expand(-1, n_samples, -1)
            else:
                # The model already supplies a sample axis: draw once per entry.
                draws = torch.multinomial(probs.reshape(-1, num_classes), 1)
                draws = draws.view(probs.shape[:-1])
                draws = draws.reshape(draws.size(0), -1)[:, :n_samples]
                log_probs = log_probs.reshape(log_probs.size(0), -1, num_classes)[:, :n_samples]

            # Perplexity of a sequence = exp(-mean log-probability of its own tokens).
            token_log_probs = log_probs.gather(-1, draws.unsqueeze(-1)).squeeze(-1)
            pass_perplexities = torch.exp(-token_log_probs.mean(dim=0)).tolist()

            for token_ids, perplexity in zip(draws.t().tolist(), pass_perplexities):
                sequences.append("".join(NUM_TO_LETTER[token] for token in token_ids))
                perplexities.append(perplexity)
                temperatures.append(temperature)

    return (
        sequences[:total_samples],
        perplexities[:total_samples],
        temperatures[:total_samples],
    )


def sample_sequences(model, featurizer, structure_data, config: Dict[str, Any], logit_bias=None):
    """Sample RNA sequences from the model.

    Each pass draws a temperature uniformly from
    ``[temperature_min, temperature_max]``, scales the model logits by it and
    samples up to ``n_samples`` sequences.

    Args:
        model: Callable returning logits for the featurized data.
        featurizer: Callable turning ``structure_data`` into model input; its
            result must support ``.to(device)``.
        structure_data: Structure description handed to ``featurizer``.
        config: Needs "n_pass", "n_samples", "total_samples",
            "temperature_min" and "temperature_max"; "device" is optional.
        logit_bias: Optional ``(length, classes)`` tensor added to the
            temperature-scaled logits, e.g. from
            ``create_partial_seq_logit_bias``.

    Returns:
        Tuple ``(sequences, perplexities)`` with at most
        ``config["total_samples"]`` entries each.

    Raises:
        ValueError: If the logits have fewer than two dimensions or do not
            match the shape of ``logit_bias``.
    """
    sequences, perplexities, _ = _sample_with_temperatures(
        model, featurizer, structure_data, config, logit_bias=logit_bias
    )
    return sequences, perplexities


# ==============================================================================
# Core Function (main logic extracted from use case)
# ==============================================================================
def run_rna_inverse_design(
    pdb_file: Optional[Union[str, Path]] = None,
    secondary_structure: Optional[str] = None,
    partial_seq: Optional[str] = None,
    mode: str = "2d",
    output_dir: Optional[Union[str, Path]] = None,
    config: Optional[Dict[str, Any]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Generate RNA sequences using gRNAde inverse design.

    Args:
        pdb_file: Path to PDB file (required for 3d mode)
        secondary_structure: Secondary structure in dot-bracket notation (required for 2d mode)
        partial_seq: Partial sequence constraints (optional). Positions holding
            A, G, C or U are pinned to that nucleotide; other characters are
            left free.
        mode: Design mode - "2d" or "3d"
        output_dir: Directory to save results (optional)
        config: Configuration dict (uses DEFAULT_CONFIG if not provided)
        **kwargs: Override specific config parameters

    Returns:
        On success, a dict containing:
            - success: True
            - sequences: List of generated sequences
            - perplexities: List of perplexity scores
            - results_df: DataFrame with one row per sequence
            - config_used: Configuration used for generation
            - output_file: Path to output file (if saved)
            - metadata: Execution metadata
        If sampling or saving fails, a dict with ``success`` False plus
        ``error``, ``error_type``, ``config_used`` and ``metadata``.

    Raises:
        ValueError: If the input required by ``mode`` is missing, or if
            ``partial_seq`` and ``secondary_structure`` differ in length
            (2d mode).
        FileNotFoundError: If the PDB file or the model checkpoint is missing.
        ImportError: If the gRNAde repo modules cannot be imported.

    Example:
        >>> result = run_rna_inverse_design(
        ...     secondary_structure="((((....))))",
        ...     mode="2d",
        ...     output_dir="results"
        ... )
        >>> print(f"Generated {len(result['sequences'])} sequences")
    """
    # Precedence: DEFAULT_CONFIG < config < kwargs; the explicit mode always wins.
    config = {**DEFAULT_CONFIG, **(config or {}), **kwargs}
    config["mode"] = mode

    set_seed(config["seed"])

    # Validate inputs before the (expensive) model load.
    if mode == "3d" and pdb_file is None:
        raise ValueError("PDB file required for 3d mode")
    if mode == "2d" and secondary_structure is None:
        raise ValueError("Secondary structure required for 2d mode")
    if (
        mode == "2d"
        and partial_seq is not None
        and len(partial_seq) != len(secondary_structure)
    ):
        raise ValueError(
            f"Partial sequence length ({len(partial_seq)}) must match "
            f"secondary structure length ({len(secondary_structure)})"
        )

    model, imports = load_model_checkpoint(config)
    featurizer = imports["RNAGraphFeaturizer"]()

    if mode == "3d":
        # This is simplified - the full implementation would use pdb_to_tensor
        pdb_file = Path(pdb_file)
        if not pdb_file.exists():
            raise FileNotFoundError(f"PDB file not found: {pdb_file}")
        # Placeholder for PDB data - would need full featurizer implementation
        structure_data = {"pdb_path": str(pdb_file), "mode": "3d"}
    else:
        structure_data = {"secondary_structure": secondary_structure, "mode": "2d"}

    try:
        # Turn the user's constraints into a bias on the sampling distribution.
        logit_bias = create_partial_seq_logit_bias(
            partial_seq, torch.device(config.get("device", "cpu"))
        )
        sequences, perplexities, temperatures = _sample_with_temperatures(
            model, featurizer, structure_data, config, logit_bias=logit_bias
        )

        results_df = pd.DataFrame({
            "sequence": sequences,
            "perplexity": perplexities,
            # The temperature each sequence was actually sampled at.
            "temperature": temperatures,
            "seed": [config["seed"]] * len(sequences),
            "mode": [mode] * len(sequences),
            "length": [len(seq) for seq in sequences],
        })

        # Save output if requested
        output_path = None
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = output_dir / f"rna_designs_{mode}_{timestamp}.csv"
            results_df.to_csv(output_path, index=False)

        return {
            "success": True,
            "sequences": sequences,
            "perplexities": perplexities,
            "results_df": results_df,
            "config_used": config,
            "output_file": str(output_path) if output_path else None,
            "metadata": {
                "mode": mode,
                "num_sequences": len(sequences),
                "pdb_file": str(pdb_file) if pdb_file else None,
                "secondary_structure": secondary_structure,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        # Deliberate boundary: report the failure to the caller instead of raising.
        return {
            "success": False,
            "error": str(e),
            "error_type": type(e).__name__,
            "config_used": config,
            "metadata": {
                "mode": mode,
                "timestamp": datetime.now().isoformat(),
            },
        }


# ==============================================================================
# CLI Interface
# ==============================================================================
def main():
    """Parse the command line, run the design and return a process exit code.

    Returns:
        0 on success, 1 when ``run_rna_inverse_design`` reports an error.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--pdb', help='PDB file path (required for 3d mode)')
    parser.add_argument('--secondary_structure',
                        help='Secondary structure in dot-bracket notation (required for 2d mode)')
    parser.add_argument('--partial_seq', help='Partial sequence constraints')
    parser.add_argument('--mode', choices=['2d', '3d'], default='2d', help='Design mode')
    parser.add_argument('--output_dir', help='Output directory')
    parser.add_argument('--config', help='Config file (JSON)')

    # Model parameters
    parser.add_argument('--total_samples', type=int, help='Total number of sequences to generate')
    parser.add_argument('--n_pass', type=int, help='Number of passes')
    parser.add_argument('--temperature_min', type=float, help='Minimum temperature')
    parser.add_argument('--temperature_max', type=float, help='Maximum temperature')
    parser.add_argument('--seed', type=int, help='Random seed')

    args = parser.parse_args()

    # Load config if provided
    config = None
    if args.config:
        with open(args.config, encoding="utf-8") as f:
            config = json.load(f)

    # Forward only the options the user actually set, so config/defaults survive.
    override_names = ['total_samples', 'n_pass', 'temperature_min', 'temperature_max', 'seed']
    kwargs = {
        name: getattr(args, name)
        for name in override_names
        if getattr(args, name) is not None
    }

    result = run_rna_inverse_design(
        pdb_file=args.pdb,
        secondary_structure=args.secondary_structure,
        partial_seq=args.partial_seq,
        mode=args.mode,
        output_dir=args.output_dir,
        config=config,
        **kwargs
    )

    if result.get('success', True):
        print(f"✅ Success: Generated {len(result.get('sequences', []))} sequences")
        if result.get('output_file'):
            print(f"📁 Output saved to: {result['output_file']}")
    else:
        print(f"❌ Error: {result.get('error', 'Unknown error')}")
        return 1

    return 0


if __name__ == '__main__':
    # SystemExit works even where the site-provided exit() helper is absent.
    raise SystemExit(main())

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Biomolecular-Design-Nexus/grnade_mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server