Skip to main content
Glama

Wazuh MCP Server

by gensecaihq
wazuh_indexer_client.py23.8 kB
"""Production-grade Wazuh Indexer API client for querying alerts and vulnerabilities.""" import asyncio import time from datetime import datetime, timedelta from typing import Dict, Any, Optional, List from urllib.parse import urljoin import aiohttp import json import sys # Clean imports within the package from wazuh_mcp_server.config import WazuhConfig from wazuh_mcp_server.utils.exceptions import ( AuthenticationError, AuthorizationError, ConnectionError, APIError, RateLimitError, handle_api_error, handle_connection_error ) from wazuh_mcp_server.utils.logging import get_logger, log_performance, LogContext from wazuh_mcp_server.utils.rate_limiter import global_rate_limiter, RateLimitConfig from wazuh_mcp_server.utils.validation import validate_alert_query, sanitize_string from wazuh_mcp_server.utils.production_error_handler import production_error_handler from wazuh_mcp_server.utils.ssl_config import SSLConfigurationManager, SSLConfig from wazuh_mcp_server.utils.error_recovery import error_recovery_manager from wazuh_mcp_server.api.wazuh_field_mappings import WazuhFieldMapper, WazuhVersion logger = get_logger(__name__) class WazuhIndexerClient: """Production-grade client for Wazuh Indexer (OpenSearch/Elasticsearch) API.""" def __init__(self, config: WazuhConfig): self.config = config self.session: Optional[aiohttp.ClientSession] = None # Use indexer settings if available, otherwise fallback to server settings self.host = getattr(config, 'indexer_host', config.host) self.port = getattr(config, 'indexer_port', 9200) self.username = getattr(config, 'indexer_username', config.username) self.password = getattr(config, 'indexer_password', config.password) self.verify_ssl = getattr(config, 'indexer_verify_ssl', config.verify_ssl) self.base_url = f"https://{self.host}:{self.port}" # Initialize field mapper for schema compatibility wazuh_version = getattr(config, 'wazuh_version', None) if wazuh_version: # Extract major.minor version import re version_match = re.search(r'(\d+)\.(\d+)', wazuh_version) if version_match: major = int(version_match.group(1)) minor = int(version_match.group(2)) if major == 4 and minor == 8: self.field_mapper = WazuhFieldMapper(WazuhVersion.V4_8_X) elif major == 4 and minor >= 9: self.field_mapper = WazuhFieldMapper(WazuhVersion.V4_9_X) else: # For older or newer versions, default to 4.8.x self.field_mapper = WazuhFieldMapper(WazuhVersion.V4_8_X) else: self.field_mapper = WazuhFieldMapper(WazuhVersion.V4_8_X) else: self.field_mapper = WazuhFieldMapper(WazuhVersion.V4_8_X) # Default to 4.8.x # Configure rate limiting for indexer global_rate_limiter.configure_endpoint( "wazuh_indexer", RateLimitConfig(max_requests=200, time_window=60) # 200 requests per minute ) # Performance and health metrics self.request_count = 0 self.error_count = 0 self.last_successful_request = None self.connection_pool_stats = {} # SSL/TLS configuration validation self._validate_ssl_config() logger.info(f"Initialized production Wazuh Indexer client for {self.host}:{self.port}", extra={ "details": { "verify_ssl": self.verify_ssl, "field_mapper_version": self.field_mapper.version.value, "base_url": self.base_url } }) def _validate_ssl_config(self): """Validate SSL/TLS configuration using production-grade SSL manager.""" # Create SSL configuration from indexer settings ssl_config = SSLConfig( verify_ssl=self.verify_ssl, ca_bundle_path=getattr(self.config, 'indexer_ca_bundle_path', None), client_cert_path=getattr(self.config, 'indexer_client_cert_path', None), client_key_path=getattr(self.config, 'indexer_client_key_path', None), allow_self_signed=getattr(self.config, 'indexer_allow_self_signed', True), ssl_timeout=getattr(self.config, 'ssl_timeout', 30), auto_detect_issues=getattr(self.config, 'indexer_auto_detect_ssl_issues', True) ) # Initialize SSL manager and validate configuration self.ssl_manager = SSLConfigurationManager() is_valid = self.ssl_manager.validate_ssl_config(ssl_config) if not is_valid: logger.warning("SSL configuration has security issues", extra={ "details": { "host": self.host, "port": self.port, "verify_ssl": self.verify_ssl, "allow_self_signed": ssl_config.allow_self_signed } }) # Auto-detect SSL issues if enabled if ssl_config.auto_detect_issues and self.verify_ssl: try: detection_results = self.ssl_manager.auto_detect_ssl_issues(self.host, self.port) if detection_results.get("recommendations"): logger.info("SSL auto-detection completed", extra={ "details": { "ssl_available": detection_results.get("ssl_available"), "certificate_valid": detection_results.get("certificate_valid"), "self_signed": detection_results.get("self_signed"), "recommendations": detection_results.get("recommendations") } }) except Exception as e: logger.debug(f"SSL auto-detection failed: {e}") # Store SSL config for session creation self._ssl_config = ssl_config async def __aenter__(self): """Async context manager entry.""" await self._create_session() return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit with proper cleanup.""" if self.session: await self.session.close() logger.debug("Wazuh Indexer client session closed") async def _create_session(self): """Create aiohttp session with production-grade SSL configuration.""" timeout = aiohttp.ClientTimeout(total=self.config.request_timeout_seconds) # Get SSL connector arguments from SSL manager connector_args = self.ssl_manager.get_aiohttp_connector_args(self._ssl_config) # Create connector with SSL configuration and connection limits connector = aiohttp.TCPConnector( limit=self.config.max_connections, limit_per_host=self.config.pool_size, **connector_args ) self.session = aiohttp.ClientSession( connector=connector, timeout=timeout, headers={"User-Agent": "WazuhMCP/2.1.0"}, auth=aiohttp.BasicAuth(self.username, self.password) ) logger.debug("Created Wazuh Indexer session with production-grade SSL configuration", extra={ "details": { "ssl_verification": self._ssl_config.verify_ssl, "max_connections": self.config.max_connections, "pool_size": self.config.pool_size } }) async def _request( self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Make authenticated request to Wazuh Indexer with production error handling.""" async def _make_request() -> Dict[str, Any]: # Rate limit requests await global_rate_limiter.enforce_rate_limit("wazuh_indexer") url = urljoin(self.base_url, endpoint) request_id = f"indexer_req_{int(time.time() * 1000)}" with LogContext(request_id): logger.debug(f"Making {method} request to {endpoint}", extra={ "details": {"endpoint": endpoint, "has_data": data is not None} }) kwargs = { "headers": { "Content-Type": "application/json", "User-Agent": "WazuhMCP-Indexer/2.1.0" } } if data: kwargs["json"] = data start_time = time.time() async with self.session.request(method, url, **kwargs) as response: self.request_count += 1 response_time = time.time() - start_time # Log performance metrics logger.debug(f"Indexer request completed", extra={ "details": { "status": response.status, "response_time_ms": round(response_time * 1000, 2), "endpoint": endpoint } }) if response.status not in [200, 201]: response_data = None try: response_data = await response.json() except (aiohttp.ClientError, json.JSONDecodeError, ValueError): # Try to get text response for better error details try: response_text = await response.text() response_data = {"error": response_text} except (aiohttp.ClientError, UnicodeDecodeError): response_data = {"error": f"HTTP {response.status}"} # Enhanced error logging for production logger.error(f"Indexer API error: {response.status}", extra={ "details": { "status": response.status, "endpoint": endpoint, "response_data": response_data, "headers": dict(response.headers) } }) handle_api_error(response.status, response_data) result = await response.json() self.last_successful_request = datetime.utcnow() # Validate response structure self._validate_response_structure(result, endpoint) return result # Use production error handler with retry logic and recovery try: return await production_error_handler.execute_with_retry( _make_request, f"indexer_{method.lower()}", "indexer", endpoint ) except Exception as e: # Use error recovery manager for intelligent recovery recovery_result = await error_recovery_manager.handle_error( e, f"wazuh_indexer_{method.lower()}", retry_func=_make_request, context={"endpoint": endpoint, "method": method, "data": data} ) if recovery_result.get("success"): return recovery_result.get("data") else: # Re-raise if recovery failed raise def _validate_response_structure(self, response: Dict[str, Any], endpoint: str): """Validate Indexer API response structure for production compatibility.""" if "_search" in endpoint: # Validate search response structure if "hits" not in response: logger.warning(f"Unexpected search response structure: missing 'hits'", extra={ "details": { "endpoint": endpoint, "response_keys": list(response.keys()) } }) elif "total" not in response.get("hits", {}): logger.warning(f"Search response missing total count", extra={ "details": {"endpoint": endpoint} }) elif "_cluster/health" in endpoint: # Validate cluster health response expected_fields = ["status", "cluster_name", "number_of_nodes"] missing_fields = [field for field in expected_fields if field not in response] if missing_fields: logger.warning(f"Cluster health response missing fields: {missing_fields}", extra={ "details": {"endpoint": endpoint} }) @log_performance async def search_alerts( self, limit: int = 100, offset: int = 0, level: Optional[int] = None, sort: str = "-timestamp", time_range: Optional[int] = None, agent_id: Optional[str] = None ) -> Dict[str, Any]: """Search alerts in Wazuh Indexer with production-grade field mapping.""" # Use field mapper for proper index pattern index_pattern = self.field_mapper.get_index_pattern("alerts") # Map sort field using field mapper mapped_sort_field = self.field_mapper.get_sort_field(sort) sort_field, sort_order = (mapped_sort_field[1:], "desc") if mapped_sort_field.startswith("-") else (mapped_sort_field, "asc") # Build production-grade Elasticsearch query query = { "size": min(limit, self.config.max_alerts_per_query), "from": offset, "sort": [ { sort_field: { "order": sort_order, "unmapped_type": "date" if "timestamp" in sort_field else "keyword" } } ], "query": { "bool": { "must": [], "filter": [] } }, "_source": { "excludes": ["full_log"] # Exclude large fields for performance } } # Add filters using field mapper if level is not None: rule_level_field = self.field_mapper.map_server_to_indexer_field("rule.level", "alert") query["query"]["bool"]["filter"].append({ "term": {rule_level_field: level} }) if time_range: end_time = datetime.utcnow() start_time = end_time - timedelta(seconds=time_range) timestamp_field = self.field_mapper.map_server_to_indexer_field("timestamp", "alert") query["query"]["bool"]["filter"].append({ "range": { timestamp_field: { "gte": start_time.isoformat() + "Z", "lte": end_time.isoformat() + "Z", "format": "strict_date_optional_time" } } }) if agent_id: clean_agent_id = sanitize_string(agent_id, 20) agent_id_field = self.field_mapper.map_server_to_indexer_field("agent.id", "alert") query["query"]["bool"]["filter"].append({ "term": {f"{agent_id_field}.keyword": clean_agent_id} # Use keyword field for exact match }) # If no filters, use match_all if not query["query"]["bool"]["must"] and not query["query"]["bool"]["filter"]: query["query"] = { "match_all": {} } # Add aggregations for monitoring and debugging query["aggs"] = { "rule_levels": { "terms": { "field": "rule.level", "size": 10 } }, "top_agents": { "terms": { "field": "agent.name.keyword", "size": 5 } } } logger.info(f"Searching alerts in Indexer", extra={ "details": { "index_pattern": index_pattern, "limit": limit, "level": level, "agent_id": agent_id, "sort_field": sort_field, "query_size": len(str(query)) } }) result = await self._request("POST", f"/{index_pattern}/_search", data=query) # Validate result before transformation issues = self._validate_alert_response(result) if issues: logger.warning(f"Alert response validation issues: {issues}") # Transform to match Server API format return self._transform_alerts_response(result) def _validate_alert_response(self, response: Dict[str, Any]) -> List[str]: """Validate alert search response using field mapper.""" return self.field_mapper.validate_field_compatibility( response.get("hits", {}).get("hits", [{}])[0].get("_source", {}), "alert" ) @log_performance async def search_vulnerabilities( self, agent_id: Optional[str] = None, cve_id: Optional[str] = None, limit: int = 100 ) -> Dict[str, Any]: """Search vulnerabilities in Wazuh Indexer.""" query = { "size": min(limit, 1000), "sort": [ { "@timestamp": { "order": "desc" } } ], "query": { "bool": { "must": [] } } } # Add filters if agent_id: clean_agent_id = sanitize_string(agent_id, 20) query["query"]["bool"]["must"].append({ "term": {"agent.id": clean_agent_id} }) if cve_id: clean_cve_id = sanitize_string(cve_id, 50) query["query"]["bool"]["must"].append({ "term": {"vulnerability.id": clean_cve_id} }) # If no filters, match all if not query["query"]["bool"]["must"]: query["query"] = {"match_all": {}} logger.info(f"Searching vulnerabilities in Indexer", extra={ "details": {"agent_id": agent_id, "cve_id": cve_id, "limit": limit} }) # Use field mapper to get proper index pattern vuln_index = self.field_mapper.get_index_pattern("vulnerabilities") result = await self._request("POST", f"/{vuln_index}/_search", data=query) # Transform to match Server API format return self._transform_vulnerabilities_response(result) def _transform_alerts_response(self, indexer_response: Dict[str, Any]) -> Dict[str, Any]: """Transform Indexer response to match Server API format.""" hits = indexer_response.get("hits", {}) total = hits.get("total", {}) # Handle different total formats total_count = total.get("value", 0) if isinstance(total, dict) else total alerts = [] for hit in hits.get("hits", []): source = hit.get("_source", {}) alerts.append(source) return { "data": { "affected_items": alerts, "total_affected_items": total_count, "total_failed_items": 0, "failed_items": [] }, "message": "Alerts retrieved successfully from Indexer", "error": 0 } def _transform_vulnerabilities_response(self, indexer_response: Dict[str, Any]) -> Dict[str, Any]: """Transform vulnerabilities response to match Server API format.""" hits = indexer_response.get("hits", {}) total = hits.get("total", {}) # Handle different total formats total_count = total.get("value", 0) if isinstance(total, dict) else total vulnerabilities = [] for hit in hits.get("hits", []): source = hit.get("_source", {}) vulnerabilities.append(source) return { "data": { "affected_items": vulnerabilities, "total_affected_items": total_count, "total_failed_items": 0, "failed_items": [] }, "message": "Vulnerabilities retrieved successfully from Indexer", "error": 0 } async def health_check(self) -> Dict[str, Any]: """Perform health check of Wazuh Indexer with error recovery.""" async def _do_health_check() -> Dict[str, Any]: result = await self._request("GET", "/_cluster/health") health_data = { "status": "healthy" if result.get("status") in ["green", "yellow"] else "unhealthy", "cluster_name": result.get("cluster_name", "unknown"), "cluster_status": result.get("status", "unknown"), "number_of_nodes": result.get("number_of_nodes", 0), "client_stats": { "total_requests": self.request_count, "error_count": self.error_count, "error_rate_percentage": round((self.error_count / max(self.request_count, 1)) * 100, 2) }, "last_check": datetime.utcnow().isoformat() } logger.info("Indexer health check passed", extra={"details": health_data}) return health_data try: return await _do_health_check() except Exception as e: # Use error recovery for health checks recovery_result = await error_recovery_manager.handle_error( e, "wazuh_indexer_health_check", retry_func=_do_health_check, context={"host": self.host, "port": self.port} ) if recovery_result.get("success"): return recovery_result.get("data") # Fallback to degraded health status health_data = { "status": "unhealthy", "error": str(e), "client_stats": { "total_requests": self.request_count, "error_count": self.error_count }, "last_check": datetime.utcnow().isoformat(), "recovery_attempted": True } logger.error("Indexer health check failed after recovery attempts", extra={"details": health_data}) return health_data def get_metrics(self) -> Dict[str, Any]: """Get client performance metrics.""" return { "total_requests": self.request_count, "error_count": self.error_count, "error_rate": (self.error_count / max(self.request_count, 1)) * 100, "base_url": self.base_url }

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/gensecaihq/Wazuh-MCP-Server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server