"""gRPC server implementation for MCP Server.
This module provides a simplified gRPC-like server using asyncio.
In production, use proper gRPC code generation from .proto files.
"""
import asyncio
import json
import logging
import uuid
from pathlib import Path
from typing import Dict, Optional
import yaml
from .clients.remediation_client import RemediationClient
from .config import Settings, get_settings
from .logging_config import setup_logging
from .models.scenario import Scenario
from .orchestration.engine import run_scenario
from .orchestration.fsm import ScenarioContext, State
from .services.eval_service import EvalService
from .services.executor_service import ExecutorService
from .services.fault_service import FaultService
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ScenarioServiceImpl:
    """Implementation of ScenarioService.

    Holds an in-memory registry of scenario YAML documents keyed by scenario
    ID and starts scenario runs as background asyncio tasks.
    """

    def __init__(
        self,
        settings: Settings,
        remediation_client: RemediationClient,
        fault_service: FaultService,
        executor_service: ExecutorService,
        eval_service: EvalService,
    ):
        """Initialize scenario service.

        Args:
            settings: Application settings; ``log_dir`` is used for report
                URIs and is handed to the orchestration engine.
            remediation_client: Client passed through to the engine.
            fault_service: Fault service passed through to the engine.
            executor_service: Executor service passed through to the engine.
            eval_service: Evaluation service passed through to the engine.
        """
        self.settings = settings
        self.remediation_client = remediation_client
        self.fault_service = fault_service
        self.executor_service = executor_service
        self.eval_service = eval_service
        # Track active runs.
        # NOTE(review): entries are added in run_scenario() and never removed
        # by this class, so the mapping grows for the life of the process --
        # confirm whether finished runs should be pruned.
        self.active_runs: Dict[str, ScenarioContext] = {}
        self.scenario_registry: Dict[str, str] = {}  # id -> yaml
        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set = set()

    @staticmethod
    def _run_failure(message: str) -> Dict:
        """Build the run_scenario() response for a run that never started."""
        return {
            "run_id": "",
            "status": {"code": State.FAIL.value, "message": message},
            "report_uri": "",
        }

    async def run_scenario(
        self,
        scenario_yaml: str,
        scenario_id: Optional[str] = None,
        bindings: Optional[Dict[str, str]] = None,
    ) -> Dict:
        """Validate a scenario and start running it in the background.

        Args:
            scenario_yaml: YAML scenario definition. When empty and
                ``scenario_id`` is given, the registry is consulted instead.
            scenario_id: Optional scenario ID (for registry lookup).
            bindings: Variable bindings; ``None`` is treated as no bindings.

        Returns:
            Dict with ``run_id``, ``status`` (``code`` and ``message``) and
            ``report_uri``. On validation failure ``run_id`` and
            ``report_uri`` are empty strings and the status code is FAIL.
            On success the status code is INIT; the run itself continues
            after this method returns.
        """
        # Fall back to the registry only when no inline YAML was supplied.
        if scenario_id and not scenario_yaml:
            scenario_yaml = self.scenario_registry.get(scenario_id, "")
        if not scenario_yaml:
            return self._run_failure("No scenario provided")

        try:
            scenario_data = yaml.safe_load(scenario_yaml)
            # safe_load returns None / list / scalar for non-mapping documents;
            # reject those explicitly instead of failing inside ``**``.
            if not isinstance(scenario_data, dict):
                raise ValueError(
                    "top-level YAML must be a mapping, "
                    f"got {type(scenario_data).__name__}"
                )
            scenario = Scenario(**scenario_data)
        except Exception as e:
            # Boundary handler: any parse/validation problem becomes a FAIL
            # response rather than an exception to the caller.
            logger.error(f"Failed to load scenario: {e}")
            return self._run_failure(f"Invalid scenario: {e}")

        # Create context
        run_id = f"run-{uuid.uuid4().hex[:12]}"
        context = ScenarioContext(
            run_id=run_id,
            scenario=scenario,
            bindings=bindings or {},
        )
        self.active_runs[run_id] = context
        logger.info(f"Starting scenario run: {run_id}")

        # Run scenario asynchronously. Keep a strong reference until the task
        # completes so it cannot be garbage collected mid-execution.
        task = asyncio.create_task(self._execute_scenario(context))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        return {
            "run_id": run_id,
            "status": {"code": State.INIT.value, "message": "Scenario started"},
            "report_uri": f"{self.settings.log_dir}/runs/{run_id}/report.json",
        }

    async def _execute_scenario(self, context: ScenarioContext) -> None:
        """Drive the orchestration engine for one run, logging each step.

        Exceptions raised by the engine are logged with a traceback and not
        re-raised, since this coroutine runs as a background task.
        """
        try:
            # ``run_scenario`` here is the module-level engine function, not
            # the method of the same name on this class.
            async for step_result in run_scenario(
                context=context,
                remediation_client=self.remediation_client,
                fault_service=self.fault_service,
                executor_service=self.executor_service,
                eval_service=self.eval_service,
                log_dir=self.settings.log_dir,
            ):
                logger.info(
                    f"[{context.run_id}] {step_result.state.value}: {step_result.message}"
                )
        except Exception as e:
            logger.error(f"Scenario execution failed: {e}", exc_info=True)

    async def list_scenarios(self) -> Dict:
        """List registered scenarios.

        Returns:
            Dict with a single ``scenario_ids`` key holding the registered IDs.
        """
        return {"scenario_ids": list(self.scenario_registry)}

    async def get_scenario(self, scenario_id: str) -> Dict:
        """Get a scenario by ID.

        Returns:
            Dict with ``scenario_yaml`` and ``status``. An unknown ID (or one
            registered with empty YAML) yields empty YAML and a FAIL status.
        """
        scenario_yaml = self.scenario_registry.get(scenario_id, "")
        if not scenario_yaml:
            return {
                "scenario_yaml": "",
                "status": {
                    "code": State.FAIL.value,
                    "message": f"Scenario not found: {scenario_id}",
                },
            }
        return {
            "scenario_yaml": scenario_yaml,
            "status": {"code": State.PASS.value, "message": "Success"},
        }

    def register_scenario(self, scenario_id: str, scenario_yaml: str) -> None:
        """Register a scenario, replacing any existing entry with the same ID."""
        self.scenario_registry[scenario_id] = scenario_yaml
        logger.info(f"Registered scenario: {scenario_id}")
class MCPServer:
    """Main MCP Server.

    Wires together the remediation client and the fault, executor and
    evaluation services, and exposes them through a ScenarioServiceImpl.
    """

    def __init__(self, settings: Settings):
        """Initialize MCP server.

        Args:
            settings: Application settings; ``http.base_url``,
                ``grpc.timeout`` and ``log_dir`` are read here.
        """
        self.settings = settings
        # Guards stop() so the remediation client is closed at most once.
        self._stopped = False
        # Initialize services
        # NOTE(review): the client is built from the HTTP base URL but uses
        # the gRPC timeout setting -- confirm this is intended.
        self.remediation_client = RemediationClient(
            base_url=settings.http.base_url,
            timeout=settings.grpc.timeout,
        )
        self.fault_service = FaultService()
        self.executor_service = ExecutorService(log_dir=settings.log_dir)
        self.eval_service = EvalService()
        # Initialize scenario service
        self.scenario_service = ScenarioServiceImpl(
            settings=settings,
            remediation_client=self.remediation_client,
            fault_service=self.fault_service,
            executor_service=self.executor_service,
            eval_service=self.eval_service,
        )

    async def start(self) -> None:
        """Start the server and block until interrupted or cancelled.

        Loads scenarios from disk, then idles. On shutdown the server is
        stopped before control returns to the caller.

        Raises:
            asyncio.CancelledError: Re-raised after cleanup when the
                surrounding task is cancelled.
        """
        logger.info(
            f"Starting MCP Server on {self.settings.grpc.host}:{self.settings.grpc.port}"
        )
        # Load scenarios from scenarios directory
        await self._load_scenarios()
        logger.info("MCP Server started successfully")
        logger.info(f"Log directory: {self.settings.log_dir}")
        # Keep server running
        try:
            while True:
                await asyncio.sleep(1)
        except (KeyboardInterrupt, asyncio.CancelledError) as exc:
            # Under asyncio.run(), Ctrl+C is normally delivered to this
            # coroutine as a cancellation rather than KeyboardInterrupt, so
            # both must trigger cleanup.
            logger.info("Shutting down...")
            await self.stop()
            if isinstance(exc, asyncio.CancelledError):
                # Cancellation must keep propagating to the task's owner.
                raise

    async def stop(self) -> None:
        """Stop the server and close the remediation client.

        Safe to call more than once; calls after the first do nothing.
        """
        if self._stopped:
            return
        self._stopped = True
        logger.info("Stopping MCP Server")
        await self.remediation_client.close()
        logger.info("MCP Server stopped")

    async def _load_scenarios(self) -> None:
        """Load scenarios from scenarios directory.

        Registers every ``*.yaml`` file in ``./scenarios`` (relative to the
        current working directory) whose document has a non-empty
        ``meta.id``. Files are processed in sorted order so that, when two
        files share an ID, the one that wins is deterministic. A failure in
        one file is logged and does not prevent the others from loading.
        """
        scenarios_dir = Path("scenarios")
        if not scenarios_dir.exists():
            logger.warning(f"Scenarios directory not found: {scenarios_dir}")
            return
        for scenario_file in sorted(scenarios_dir.glob("*.yaml")):
            try:
                # Explicit encoding: the platform default is not always UTF-8.
                scenario_yaml = scenario_file.read_text(encoding="utf-8")
                scenario_data = yaml.safe_load(scenario_yaml)
                # Empty or non-mapping documents simply have no ID.
                meta = (
                    scenario_data.get("meta")
                    if isinstance(scenario_data, dict)
                    else None
                )
                scenario_id = meta.get("id", "") if isinstance(meta, dict) else ""
                if scenario_id:
                    self.scenario_service.register_scenario(scenario_id, scenario_yaml)
                else:
                    logger.warning(f"Scenario missing ID: {scenario_file}")
            except Exception as e:
                # Best-effort load: report and continue with the next file.
                logger.error(f"Failed to load scenario {scenario_file}: {e}")
async def serve():
    """Configure logging, print the startup banner and run the MCP Server.

    Any ``Exception`` raised while the server is running is logged with its
    traceback, after which the server is stopped.
    """
    # Settings first: logging setup needs the log directory.
    config = get_settings()
    setup_logging(config.log_dir)

    divider = "=" * 60
    banner_lines = (
        divider,
        "MCP Server - AI-driven Remediation Testing",
        divider,
        "Version: 1.0.0",
        f"gRPC: {config.grpc.host}:{config.grpc.port}",
        f"HTTP API: {config.http.base_url}",
        f"Log directory: {config.log_dir}",
        divider,
    )
    for banner_line in banner_lines:
        logger.info(banner_line)

    # Build the server, then run it until it returns or fails.
    mcp_server = MCPServer(config)
    try:
        await mcp_server.start()
    except Exception as err:
        logger.exception(f"Server error: {err}")
        await mcp_server.stop()
def main():
    """Synchronous entry point: run ``serve()`` on a fresh event loop."""
    server_coroutine = serve()
    asyncio.run(server_coroutine)
# Run the server only when this file is executed as the main program,
# not when it is imported.
if __name__ == "__main__":
    main()