Skip to main content
Glama
enkryptai

Enkrypt AI Secure MCP Gateway

Official
by enkryptai
enkrypt_provider.py58.9 kB
"""Enkrypt guardrail provider.""" import asyncio import time from typing import Any, ClassVar, Dict, List, Optional import aiohttp from secure_mcp_gateway.error_handling import error_handling_context, error_logger from secure_mcp_gateway.exceptions import ( ErrorCode, ErrorContext, ErrorSeverity, GuardrailError, create_guardrail_error, ) from secure_mcp_gateway.plugins.guardrails.base import ( GuardrailAction, GuardrailProvider, GuardrailRequest, GuardrailResponse, GuardrailViolation, InputGuardrail, OutputGuardrail, PIIHandler, ServerRegistrationRequest, ToolRegistrationRequest, ViolationType, ) from secure_mcp_gateway.utils import logger class EnkryptInputGuardrail: """ Enkrypt implementation of InputGuardrail. This class is fully self-contained and makes direct API calls to Enkrypt. """ def __init__(self, config: Dict[str, Any], api_key: str, base_url: str): self.config = config self.api_key = api_key self.base_url = base_url self.policy_name = config.get("policy_name", "") self.block_list = config.get("block", []) self.additional_config = config.get("additional_config", {}) # API endpoints self.guardrail_url = f"{base_url}/guardrails/policy/detect" # Debug mode self.debug = config.get("debug", False) async def validate(self, request: GuardrailRequest) -> GuardrailResponse: """Validate input using Enkrypt guardrails.""" start_time = time.time() # Create error context for correlation tracking context = ErrorContext( operation="input_guardrail_validation", server_name=getattr(request, "server_name", None), tool_name=request.tool_name, additional_context={ "policy_name": self.policy_name, "content_length": len(request.content), }, ) async with error_handling_context("input_guardrail_validation", context): try: # Prepare payload payload = {"text": request.content} headers = { "X-Enkrypt-Policy": self.policy_name, "apikey": self.api_key, "Content-Type": "application/json", } if self.debug: logger.debug( f"[EnkryptInputGuardrail] Validating with policy: {self.policy_name}" ) logger.debug(f"[EnkryptInputGuardrail] Payload: {payload}") # Make API call async with aiohttp.ClientSession() as session: async with session.post( self.guardrail_url, json=payload, headers=headers ) as response: resp_json = await response.json() if self.debug: logger.debug(f"[EnkryptInputGuardrail] Response: {resp_json}") # Check for API errors if resp_json.get("error"): error = create_guardrail_error( code=ErrorCode.GUARDRAIL_API_ERROR, message=f"Enkrypt API error: {resp_json.get('error')}", context=context, ) error_logger.log_error(error) return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.CUSTOM, severity=1.0, message=f"API Error: {resp_json.get('error')}", action=GuardrailAction.BLOCK, metadata={"error": resp_json.get("error")}, ) ], metadata={ "api_error": True, "correlation_id": context.correlation_id, }, processing_time_ms=(time.time() - start_time) * 1000, ) # Parse violations from Enkrypt response violations = [] violations_detected = False if "summary" in resp_json: summary = resp_json["summary"] for policy_type in self.block_list: value = summary.get(policy_type) if value == 1 or (isinstance(value, list) and len(value) > 0): violations_detected = True violations.append( GuardrailViolation( violation_type=self._map_violation_type(policy_type), severity=0.8, # Default severity message=f"Input validation failed: {policy_type}", action=GuardrailAction.BLOCK, metadata={ "policy_type": policy_type, "value": value, "details": resp_json.get("details", {}).get( policy_type, {} ), }, ) ) # Determine overall safety is_safe = not violations_detected action = GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK processing_time_ms = (time.time() - start_time) * 1000 return GuardrailResponse( is_safe=is_safe, action=action, violations=violations, modified_content=None, metadata={ "policy_name": self.policy_name, "enkrypt_response": resp_json, }, processing_time_ms=processing_time_ms, ) except Exception as e: error = create_guardrail_error( code=ErrorCode.GUARDRAIL_VALIDATION_FAILED, message=f"Input guardrail validation failed: {e!s}", context=context, cause=e, ) error_logger.log_error(error) processing_time_ms = (time.time() - start_time) * 1000 return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.CUSTOM, severity=1.0, message=f"Validation error: {e!s}", action=GuardrailAction.BLOCK, metadata={ "exception": str(e), "correlation_id": context.correlation_id, }, ) ], metadata={ "exception": str(e), "correlation_id": context.correlation_id, "error_code": error.code.value, }, processing_time_ms=processing_time_ms, ) def get_supported_detectors(self) -> List[ViolationType]: """Get supported violation types for input.""" return [ ViolationType.PII, ViolationType.INJECTION_ATTACK, ViolationType.TOXIC_CONTENT, ViolationType.NSFW_CONTENT, ViolationType.KEYWORD_VIOLATION, ViolationType.POLICY_VIOLATION, ViolationType.BIAS, ViolationType.SPONGE_ATTACK, ] def _map_violation_type(self, enkrypt_type: str) -> ViolationType: """Map Enkrypt violation types to standard ViolationType enum.""" mapping = { "pii": ViolationType.PII, "injection_attack": ViolationType.INJECTION_ATTACK, "toxicity": ViolationType.TOXIC_CONTENT, "nsfw": ViolationType.NSFW_CONTENT, "keyword_detector": ViolationType.KEYWORD_VIOLATION, "policy_violation": ViolationType.POLICY_VIOLATION, "bias": ViolationType.BIAS, "sponge_attack": ViolationType.SPONGE_ATTACK, } return mapping.get(enkrypt_type, ViolationType.CUSTOM) class EnkryptOutputGuardrail: """ Enkrypt implementation of OutputGuardrail. This class is fully self-contained and makes direct API calls to Enkrypt. Includes ALL checks: policy, relevancy, adherence, hallucination. """ def __init__(self, config: Dict[str, Any], api_key: str, base_url: str): self.config = config self.api_key = api_key self.base_url = base_url self.policy_name = config.get("policy_name", "") self.block_list = config.get("block", []) self.additional_config = config.get("additional_config", {}) # API endpoints self.guardrail_url = f"{base_url}/guardrails/policy/detect" self.relevancy_url = f"{base_url}/guardrails/relevancy" self.adherence_url = f"{base_url}/guardrails/adherence" self.hallucination_url = f"{base_url}/guardrails/hallucination" # Thresholds self.relevancy_threshold = self.additional_config.get( "relevancy_threshold", 0.7 ) self.adherence_threshold = self.additional_config.get( "adherence_threshold", 0.8 ) # Debug mode self.debug = config.get("debug", False) async def validate( self, response_content: str, original_request: GuardrailRequest ) -> GuardrailResponse: """ Validate output using Enkrypt guardrails with ALL checks. Performs: 1. Policy detection (if enabled) 2. Relevancy check (if enabled) 3. Adherence check (if enabled) 4. Hallucination check (if enabled) """ start_time = time.time() try: violations = [] additional_metadata = {} # 1. Policy Detection (if enabled) if self.config.get("enabled", False): policy_result = await self._check_policy(response_content) additional_metadata["policy"] = policy_result if "summary" in policy_result: summary = policy_result["summary"] for policy_type in self.block_list: value = summary.get(policy_type) if value == 1 or (isinstance(value, list) and len(value) > 0): violations.append( GuardrailViolation( violation_type=self._map_violation_type( policy_type ), severity=0.8, message=f"Output validation failed: {policy_type}", action=GuardrailAction.BLOCK, metadata={ "policy_type": policy_type, "value": value, "details": policy_result.get("details", {}).get( policy_type, {} ), }, ) ) # 2. Relevancy Check (if enabled) if self.additional_config.get("relevancy", False): relevancy_result = await self._check_relevancy( original_request.content, response_content ) additional_metadata["relevancy"] = relevancy_result relevancy_score = relevancy_result.get("score", 1.0) if relevancy_score < self.relevancy_threshold: violations.append( GuardrailViolation( violation_type=ViolationType.RELEVANCY_FAILURE, severity=1.0 - relevancy_score, message=f"Response not relevant (score: {relevancy_score:.2f})", action=GuardrailAction.WARN, metadata=relevancy_result, ) ) # 3. Adherence Check (if enabled) if self.additional_config.get("adherence", False): adherence_result = await self._check_adherence( original_request.content, response_content ) additional_metadata["adherence"] = adherence_result adherence_score = adherence_result.get("score", 1.0) if adherence_score < self.adherence_threshold: violations.append( GuardrailViolation( violation_type=ViolationType.ADHERENCE_FAILURE, severity=1.0 - adherence_score, message=f"Response doesn't adhere to context (score: {adherence_score:.2f})", action=GuardrailAction.WARN, metadata=adherence_result, ) ) # 4. Hallucination Check (if enabled) if self.additional_config.get("hallucination", False): hallucination_result = await self._check_hallucination( original_request.content, response_content ) additional_metadata["hallucination"] = hallucination_result if hallucination_result.get("has_hallucination", False): violations.append( GuardrailViolation( violation_type=ViolationType.HALLUCINATION, severity=hallucination_result.get("confidence", 0.5), message="Potential hallucination detected", action=GuardrailAction.WARN, metadata=hallucination_result, ) ) # Determine overall safety # Block if there are blocking violations, warn otherwise has_blocking_violations = any( v.action == GuardrailAction.BLOCK for v in violations ) is_safe = len(violations) == 0 action = ( GuardrailAction.BLOCK if has_blocking_violations else ( GuardrailAction.WARN if len(violations) > 0 else GuardrailAction.ALLOW ) ) processing_time_ms = (time.time() - start_time) * 1000 return GuardrailResponse( is_safe=is_safe, action=action, violations=violations, modified_content=None, metadata=additional_metadata, processing_time_ms=processing_time_ms, ) except Exception as e: logger.error(f"[EnkryptOutputGuardrail] Exception: {e}") processing_time_ms = (time.time() - start_time) * 1000 return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.CUSTOM, severity=1.0, message=f"Validation error: {e!s}", action=GuardrailAction.BLOCK, metadata={"exception": str(e)}, ) ], metadata={"exception": str(e)}, processing_time_ms=processing_time_ms, ) async def _check_policy(self, text: str) -> Dict[str, Any]: """Check against policy using Enkrypt API.""" try: payload = {"text": text} headers = { "X-Enkrypt-Policy": self.policy_name, "apikey": self.api_key, "Content-Type": "application/json", } if self.debug: logger.debug( f"[EnkryptOutputGuardrail] Policy check for: {self.policy_name}" ) async with aiohttp.ClientSession() as session: async with session.post( self.guardrail_url, json=payload, headers=headers ) as response: result = await response.json() if self.debug: logger.debug(f"[EnkryptOutputGuardrail] Policy result: {result}") return result except Exception as e: logger.error(f"[EnkryptOutputGuardrail] Policy check error: {e}") return {"error": str(e)} async def _check_relevancy(self, question: str, answer: str) -> Dict[str, Any]: """Check relevancy using Enkrypt API.""" try: payload = {"question": question, "llm_answer": answer} headers = { "apikey": self.api_key, "Content-Type": "application/json", } if self.debug: logger.debug("[EnkryptOutputGuardrail] Checking relevancy") async with aiohttp.ClientSession() as session: async with session.post( self.relevancy_url, json=payload, headers=headers ) as response: result = await response.json() if self.debug: logger.debug(f"[EnkryptOutputGuardrail] Relevancy result: {result}") return result except Exception as e: logger.error(f"[EnkryptOutputGuardrail] Relevancy check error: {e}") return {"error": str(e), "score": 1.0} # Default to passing async def _check_adherence(self, context: str, answer: str) -> Dict[str, Any]: """Check adherence using Enkrypt API.""" try: payload = {"context": context, "llm_answer": answer} headers = { "apikey": self.api_key, "Content-Type": "application/json", } if self.debug: logger.debug("[EnkryptOutputGuardrail] Checking adherence") async with aiohttp.ClientSession() as session: async with session.post( self.adherence_url, json=payload, headers=headers ) as response: result = await response.json() if self.debug: logger.debug(f"[EnkryptOutputGuardrail] Adherence result: {result}") return result except Exception as e: logger.error(f"[EnkryptOutputGuardrail] Adherence check error: {e}") return {"error": str(e), "score": 1.0} # Default to passing async def _check_hallucination( self, request: str, response: str, context: str = "" ) -> Dict[str, Any]: """Check hallucination using Enkrypt API.""" try: payload = { "request_text": request, "response_text": response, "context": context, } headers = { "apikey": self.api_key, "Content-Type": "application/json", } if self.debug: logger.debug("[EnkryptOutputGuardrail] Checking hallucination") async with aiohttp.ClientSession() as session: async with session.post( self.hallucination_url, json=payload, headers=headers ) as response: result = await response.json() if self.debug: logger.debug(f"[EnkryptOutputGuardrail] Hallucination result: {result}") return result except Exception as e: logger.error(f"[EnkryptOutputGuardrail] Hallucination check error: {e}") return {"error": str(e), "has_hallucination": False} # Default to passing def get_supported_detectors(self) -> List[ViolationType]: """Get supported violation types for output.""" return [ ViolationType.PII, ViolationType.POLICY_VIOLATION, ViolationType.RELEVANCY_FAILURE, ViolationType.ADHERENCE_FAILURE, ViolationType.HALLUCINATION, ViolationType.TOXIC_CONTENT, ViolationType.NSFW_CONTENT, ] def _map_violation_type(self, enkrypt_type: str) -> ViolationType: """Map Enkrypt violation types to standard ViolationType enum.""" mapping = { "pii": ViolationType.PII, "policy_violation": ViolationType.POLICY_VIOLATION, "relevancy": ViolationType.RELEVANCY_FAILURE, "adherence": ViolationType.ADHERENCE_FAILURE, "toxicity": ViolationType.TOXIC_CONTENT, "nsfw": ViolationType.NSFW_CONTENT, "hallucination": ViolationType.HALLUCINATION, } return mapping.get(enkrypt_type, ViolationType.CUSTOM) class EnkryptPIIHandler: """ Enkrypt implementation of PIIHandler. This class is fully self-contained and makes direct API calls to Enkrypt. """ def __init__(self, api_key: str, base_url: str): self.api_key = api_key self.base_url = base_url self.pii_url = f"{base_url}/guardrails/pii" async def detect_pii(self, content: str) -> List[GuardrailViolation]: """Detect PII using Enkrypt.""" try: # Use the redact endpoint to detect PII payload = {"text": content, "mode": "request", "key": "null"} headers = { "apikey": self.api_key, "Content-Type": "application/json", } async with aiohttp.ClientSession() as session: async with session.post( self.pii_url, json=payload, headers=headers ) as response: result = await response.json() violations = [] # If text was modified, PII was detected if result.get("text") != content: violations.append( GuardrailViolation( violation_type=ViolationType.PII, severity=0.8, message="PII detected in content", action=GuardrailAction.MODIFY, metadata={ "original_length": len(content), "redacted_length": len(result.get("text", "")), }, ) ) return violations except Exception as e: logger.error(f"[EnkryptPIIHandler] PII detection error: {e}") return [] async def redact_pii(self, content: str) -> tuple[str, Dict[str, Any]]: """Redact PII using Enkrypt.""" try: payload = {"text": content, "mode": "request", "key": "null"} headers = { "apikey": self.api_key, "Content-Type": "application/json", } async with aiohttp.ClientSession() as session: async with session.post( self.pii_url, json=payload, headers=headers ) as response: result = await response.json() redacted_text = result.get("text", content) pii_key = result.get("key", "") return redacted_text, {"key": pii_key} except Exception as e: logger.error(f"[EnkryptPIIHandler] PII redaction error: {e}") return content, {} async def restore_pii(self, content: str, pii_mapping: Dict[str, Any]) -> str: """Restore PII using Enkrypt.""" try: pii_key = pii_mapping.get("key", "") if not pii_key: return content payload = {"text": content, "mode": "response", "key": pii_key} headers = { "apikey": self.api_key, "Content-Type": "application/json", } async with aiohttp.ClientSession() as session: async with session.post( self.pii_url, json=payload, headers=headers ) as response: result = await response.json() return result.get("text", content) except Exception as e: logger.error(f"[EnkryptPIIHandler] PII restoration error: {e}") return content class EnkryptServerRegistrationGuardrail: """ Handles server and tool registration validation using Enkrypt batch API. This class validates servers and tools during discovery to prevent malicious or unsafe MCP servers/tools from being registered. """ # Server validation detector config SERVER_DETECTORS: ClassVar[Dict[str, Any]] = { "injection_attack": {"enabled": True}, "policy_violation": { "enabled": True, # "coc_policy_name": "Safe Server Registration Policy", "policy_text": "Allow only safe servers to be registered for this MCP server and find any malicious servers to be blocked", "need_explanation": True, }, "keyword_detector": { "enabled": True, "banned_keywords": [ "exec", "shell", "eval", "sudo", "rm -rf", "delete_all", # Secret files "mcp.json", "claude_desktop_config.json", "enkrypt_mcp_config.json", ".env", ], }, "toxicity": {"enabled": True}, "nsfw": {"enabled": True}, "topic_detector": {"enabled": False, "topic": []}, "pii": {"enabled": False, "entities": []}, # "system_prompt": {"enabled": False}, # "copyright_ip": {"enabled": False}, "bias": {"enabled": True}, "sponge_attack": {"enabled": True}, } # Tool validation detector config TOOL_DETECTORS: ClassVar[Dict[str, Any]] = { "injection_attack": {"enabled": True}, "policy_violation": { "enabled": True, # "coc_policy_name": "Safe Tool Registration Policy", "policy_text": "Allow only safe tools to be registered for this MCP server and find any malicious tools to be blocked", "need_explanation": True, }, "keyword_detector": { "enabled": True, "banned_keywords": [ "exec", "shell", "eval", # "system", # "command", # "execute", "run_code", # "delete", # "drop", # "truncate", # "remove", "destroy", "wipe", # "format", "kill", "terminate", "exploit", "hack", "crack", "bypass", "override", # "privilege", "escalate", # "root", # "admin", "sudo", "chmod", "chown", # Secret files "mcp.json", "claude_desktop_config.json", "enkrypt_mcp_config.json", ".env", ], }, "toxicity": {"enabled": True}, "nsfw": {"enabled": True}, "topic_detector": { "enabled": False, "topic": [], }, "pii": {"enabled": False, "entities": []}, # "system_prompt": {"enabled": True}, # "copyright_ip": {"enabled": False}, "bias": {"enabled": True}, "sponge_attack": {"enabled": True}, } def __init__(self, api_key: str, base_url: str, config: Dict[str, Any] = None): import sys self.api_key = api_key self.base_url = base_url self.config = config or {} self.batch_url = f"{base_url}/guardrails/batch/detect" # Check both "debug" field and "enkrypt_log_level" for DEBUG self.debug = ( self.config.get("debug", False) or self.config.get("enkrypt_log_level", "").upper() == "DEBUG" ) logger.info( f"[EnkryptServerRegistrationGuardrail] Initialized with debug={self.debug}" ) logger.info( f"[EnkryptServerRegistrationGuardrail] Config keys: {list(self.config.keys())}" ) logger.info( f"[EnkryptServerRegistrationGuardrail] enkrypt_log_level={self.config.get('enkrypt_log_level')}" ) # Get custom detectors from config if provided registration_config = self.config.get("registration_validation", {}) if registration_config.get("custom_detectors"): self.SERVER_DETECTORS = registration_config.get( "server_detectors", self.SERVER_DETECTORS ) self.TOOL_DETECTORS = registration_config.get( "tool_detectors", self.TOOL_DETECTORS ) async def validate_server( self, request: ServerRegistrationRequest ) -> GuardrailResponse: """Validate server registration.""" start_time = time.time() try: # Build text from server metadata server_text = f"MCP Server: {request.server_name}" if request.server_description: server_text += f" - {request.server_description}" # Command triggers injection attacks. So only checking with name and description # if request.server_command: # server_text += f" | Command: {request.server_command}" if self.debug: logger.debug( f"[EnkryptServerRegistration] Validating server: {request.server_name}" ) logger.debug(f"[EnkryptServerRegistration] Text: {server_text}") # Call Enkrypt batch API response = await self._call_batch_api( texts=[server_text], detectors=self.SERVER_DETECTORS ) if self.debug: logger.debug( f"[EnkryptServerRegistration] Guardrail Response: {response}" ) # Analyze response result = response[0] violations = [] # Check each detector if result["summary"].get("injection_attack", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.INJECTION_ATTACK, severity=1.0, message="Injection attack detected in server metadata", action=GuardrailAction.BLOCK, metadata=result["details"].get("injection_attack", {}), ) ) if result["summary"].get("policy_violation", 0) == 1: policy_details = result["details"].get("policy_violation", {}) explanation = policy_details.get( "explanation", "Policy violation detected" ) violations.append( GuardrailViolation( violation_type=ViolationType.POLICY_VIOLATION, severity=1.0, message=f"Policy violation: {explanation}", action=GuardrailAction.BLOCK, metadata=policy_details, ) ) if result["summary"].get("toxicity", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.TOXIC_CONTENT, severity=0.8, message="Toxic content detected in server description", action=GuardrailAction.BLOCK, metadata=result["details"].get("toxicity", {}), ) ) if result["summary"].get("nsfw", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.NSFW_CONTENT, severity=0.8, message="NSFW content detected in server description", action=GuardrailAction.BLOCK, metadata=result["details"].get("nsfw", {}), ) ) if result["summary"].get("bias", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.BIAS, severity=0.8, message="Bias detected in server description", ) ) if result["summary"].get("sponge_attack", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.SPONGE_ATTACK, severity=0.8, message="Sponge attack detected in server description", ) ) if result["summary"].get("keyword_detector", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.KEYWORD_VIOLATION, severity=0.8, message="Keyword violation detected in server description", ) ) if result["summary"].get("pii", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.PII, severity=0.8, message="PII detected in server description", ) ) if result["summary"].get("topic_detector", 0) == 1: violations.append( GuardrailViolation( violation_type=ViolationType.TOPIC_DETECTOR, severity=0.8, message="Topic detector detected in server description", ) ) is_safe = len(violations) == 0 processing_time = (time.time() - start_time) * 1000 if self.debug: logger.debug( f"[EnkryptServerRegistration] Server validation result: {'SAFE' if is_safe else 'BLOCKED'}" ) if not is_safe: logger.debug( f"[EnkryptServerRegistration] Violations: {[v.message for v in violations]}" ) return GuardrailResponse( is_safe=is_safe, action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK, violations=violations, processing_time_ms=processing_time, metadata={ "server_name": request.server_name, "detection_details": result, }, ) except Exception as e: logger.error(f"[EnkryptServerRegistration] Error validating server: {e}") # If guardrails enabled and auth failed, fail-closed to avoid approving unsafe cfg_enabled = bool(self.config.get("enkrypt_guardrails_enabled", True)) if cfg_enabled and ("UNAUTHORIZED:" in str(e)): return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.POLICY_VIOLATION, severity=1.0, message="Guardrail authorization failed", action=GuardrailAction.BLOCK, metadata={"error": str(e)}, ) ], metadata={"error": str(e)}, ) # Otherwise fail-open return GuardrailResponse( is_safe=True, action=GuardrailAction.ALLOW, violations=[], metadata={"error": str(e)}, ) async def validate_tools( self, request: ToolRegistrationRequest ) -> GuardrailResponse: """Validate and filter tools.""" start_time = time.time() try: # Build texts array from tool descriptions texts = [] for tool in request.tools: # Handle both dict and Tool object if isinstance(tool, dict): tool_name = tool.get("name", "unknown") tool_desc = tool.get("description", "") annotations = tool.get("annotations", {}) else: # Tool object (has attributes) tool_name = getattr(tool, "name", "unknown") tool_desc = getattr(tool, "description", "") annotations = getattr(tool, "annotations", {}) or {} tool_text = f"Tool: {tool_name}" if tool_desc: tool_text += f" - {tool_desc}" # Add annotations if isinstance(annotations, dict): if annotations.get("destructiveHint"): tool_text += " [DESTRUCTIVE]" if annotations.get("readOnlyHint"): tool_text += " [READ-ONLY]" texts.append(tool_text) if self.debug: logger.debug( f"[EnkryptToolRegistration] Validating {len(texts)} tools for {request.server_name}" ) # Call Enkrypt batch API response = await self._call_batch_api( texts=texts, detectors=self.TOOL_DETECTORS ) if self.debug: logger.debug( f"[EnkryptToolRegistration] Guardrail Response: {response}" ) logger.debug(f"[EnkryptToolRegistration] Guardrail Response: {response}") # Analyze results safe_tools = [] blocked_tools = [] all_violations = [] for i, (tool, result) in enumerate(zip(request.tools, response)): # Check all enabled detectors tool_violations = [] if result["summary"].get("injection_attack", 0) == 1: tool_violations.append("injection attack detected") if result["summary"].get("policy_violation", 0) == 1: explanation = ( result["details"] .get("policy_violation", {}) .get("explanation", "policy violation") ) tool_violations.append(explanation) if result["summary"].get("toxicity", 0) == 1: tool_violations.append("toxic content") if result["summary"].get("nsfw", 0) == 1: tool_violations.append("NSFW content") if result["summary"].get("topic_detector", 0) == 1: tool_violations.append("dangerous topic detected") if result["summary"].get("bias", 0) == 1: tool_violations.append("bias detected") if result["summary"].get("sponge_attack", 0) == 1: tool_violations.append("sponge attack detected") if result["summary"].get("keyword_detector", 0) == 1: tool_violations.append("keyword violation detected") if result["summary"].get("pii", 0) == 1: tool_violations.append("PII detected") # Get tool name for reporting if isinstance(tool, dict): tool_name = tool.get("name", "unknown") else: tool_name = getattr(tool, "name", "unknown") if len(tool_violations) == 0: safe_tools.append(tool) else: blocked_tool_info = { "name": tool_name, "reasons": tool_violations, "detection_details": result, } blocked_tools.append(blocked_tool_info) # Create violation objects for reason in tool_violations: all_violations.append( GuardrailViolation( violation_type=ViolationType.POLICY_VIOLATION, severity=1.0, message=f"Blocked tool '{tool_name}': {reason}", action=GuardrailAction.BLOCK, metadata={"tool": tool_name, "reason": reason}, ) ) # Determine overall safety based on mode if request.validation_mode == "block_all": # Block all if any tool is unsafe is_safe = len(blocked_tools) == 0 else: # Filter mode: allow but filter unsafe tools is_safe = True processing_time = (time.time() - start_time) * 1000 if self.debug: logger.debug("[EnkryptToolRegistration] Validation complete:") logger.debug(f" - Total tools: {len(request.tools)}") logger.debug(f" - Safe tools: {len(safe_tools)}") logger.debug(f" - Blocked tools: {len(blocked_tools)}") if blocked_tools: logger.debug(f" - Blocked: {[t['name'] for t in blocked_tools]}") return GuardrailResponse( is_safe=is_safe, action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK, violations=all_violations, processing_time_ms=processing_time, metadata={ "server_name": request.server_name, "total_tools": len(request.tools), "safe_tools_count": len(safe_tools), "blocked_tools_count": len(blocked_tools), "blocked_tools": blocked_tools, "filtered_tools": safe_tools, "validation_mode": request.validation_mode, }, ) except Exception as e: logger.error(f"[EnkryptToolRegistration] Error validating tools: {e}") cfg_enabled = bool(self.config.get("enkrypt_guardrails_enabled", True)) # Handle timeout errors - fail closed to prevent tool registration from secure_mcp_gateway.exceptions import TimeoutError if ( isinstance(e, TimeoutError) or "GUARDRAIL_TIMEOUT:" in str(e) or "timed out" in str(e).lower() ): logger.error( f"[EnkryptToolRegistration] Guardrail timeout - blocking tool registration: {e}" ) # Extract timeout details if available timeout_duration = getattr(e, "timeout_duration", "unknown") timeout_type = getattr(e, "timeout_type", "guardrail") return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.CUSTOM, severity=1.0, message=f"Guardrail validation timed out after {timeout_duration}s", action=GuardrailAction.BLOCK, metadata={ "error": str(e), "timeout": True, "timeout_duration": timeout_duration, "timeout_type": timeout_type, }, ) ], metadata={ "error": str(e), "timeout": True, "timeout_duration": timeout_duration, "timeout_type": timeout_type, }, ) if cfg_enabled and ("UNAUTHORIZED:" in str(e)): # Fail closed: do not approve tools when guardrails enabled but unauthorized return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.POLICY_VIOLATION, severity=1.0, message="Guardrail authorization failed", action=GuardrailAction.BLOCK, metadata={"error": str(e)}, ) ], metadata={"error": str(e)}, ) # FAIL CLOSED for other errors # Log with standardized error handling from secure_mcp_gateway.error_handling import error_logger from secure_mcp_gateway.exceptions import ( ErrorCode, ErrorContext, create_guardrail_error, ) context = ErrorContext( operation="guardrail.tool_validation_error", request_id=getattr(self, "request_id", None), server_name=getattr(self, "server_name", None), ) error = create_guardrail_error( code=ErrorCode.GUARDRAIL_VALIDATION_ERROR, message=f"Tool validation failed (fail-closed): {e}", context=context, cause=e, ) error_logger.log_error(error) # Block all tools on validation error return GuardrailResponse( is_safe=False, action=GuardrailAction.BLOCK, violations=[ GuardrailViolation( violation_type=ViolationType.CUSTOM, severity=1.0, message=f"Guardrail validation error: {e}", action=GuardrailAction.BLOCK, metadata={"error": str(e)}, ) ], metadata={ "error": str(e), "filtered_tools": [], # Block all tools }, ) async def _call_batch_api( self, texts: List[str], detectors: Dict[str, Any] ) -> List[Dict[str, Any]]: """Call Enkrypt batch detection API.""" try: def _sanitize_for_json(value): """Recursively sanitize values so they can be JSON-serialized. - Ensure dict keys are strings (drop None keys) - Convert non-serializable simple types to strings """ if isinstance(value, dict): clean = {} for k, v in value.items(): if k is None: continue sk = str(k) clean[sk] = _sanitize_for_json(v) return clean if isinstance(value, list): return [_sanitize_for_json(v) for v in value] if isinstance(value, (str, int, float, bool)) or value is None: return value # Fallback to string return str(value) safe_texts = ["" if t is None else str(t) for t in (texts or [])] safe_detectors = _sanitize_for_json(detectors or {}) payload = {"texts": safe_texts, "detectors": safe_detectors} headers = { "apikey": str(self.api_key or ""), "Content-Type": "application/json", } if self.debug: logger.debug( f"[EnkryptBatchAPI] Calling batch API with {len(texts)} texts" ) # Get configurable timeout from TimeoutManager from secure_mcp_gateway.services.timeout import get_timeout_manager timeout_manager = get_timeout_manager() timeout_value = timeout_manager.get_timeout("guardrail") # Use TimeoutManager to execute the API call with proper timeout handling async def _make_api_call(): async with aiohttp.ClientSession() as session: async with session.post( self.batch_url, json=payload, headers=headers, # Remove aiohttp.ClientTimeout - let TimeoutManager handle timeout ) as response: if response.status != 200: error_text = await response.text() if response.status in (401, 403): # Mark explicitly as unauthorized so callers can fail-closed raise Exception( f"UNAUTHORIZED: API error {response.status}: {error_text}" ) raise Exception( f"API error {response.status}: {error_text}" ) result = await response.json() return result # Execute with TimeoutManager for proper metrics and escalation timeout_result = await timeout_manager.execute_with_timeout( _make_api_call, "guardrail", f"batch_api_{len(texts)}_texts" ) if not timeout_result.success: if timeout_result.error: raise timeout_result.error else: raise Exception("API call failed") result = timeout_result.result if self.debug: logger.debug("[EnkryptBatchAPI] Batch API response received") return result except Exception as e: logger.error(f"[EnkryptBatchAPI] Batch API call failed: {e}") # Use standardized error handling from secure_mcp_gateway.error_handling import error_logger from secure_mcp_gateway.exceptions import ( ErrorCode, ErrorContext, create_guardrail_error, ) # Create error context for proper tracing context = ErrorContext( operation="guardrail_batch_api", request_id=getattr(self, "request_id", None), server_name=getattr(self, "server_name", None), ) # Handle different error types with proper error codes if "UNAUTHORIZED:" in str(e): # Propagate unauthorized marker via exception so upper layers can block raise elif "timeout" in str(e).lower() or "timed out" in str(e).lower(): # Create standardized timeout error error = create_guardrail_error( code=ErrorCode.GUARDRAIL_TIMEOUT, message=f"Guardrail API call timed out: {e}", context=context, cause=e, ) error_logger.log_error(error) raise error else: # Create standardized API error error = create_guardrail_error( code=ErrorCode.GUARDRAIL_API_ERROR, message=f"Guardrail API call failed: {e}", context=context, cause=e, ) error_logger.log_error(error) # FAIL CLOSED: Raise error instead of returning safe default raise error class EnkryptGuardrailProvider(GuardrailProvider): """ Enkrypt AI guardrail provider implementation. This provider is fully self-contained with NO dependency on guardrail_service. All API calls are made directly from this provider. """ def __init__( self, api_key: str, base_url: str = "https://api.enkryptai.com", config: Dict[str, Any] = None, ): # Use provided credentials; if missing, fetch from full config as fallback self.api_key = api_key self.base_url = base_url self.config = config or {} # Initialize registration guardrail # If api_key not set or is literal 'null', fallback to config values if not self.api_key or str(self.api_key).lower() == "null": fetched_key, fetched_base = self._get_api_credentials() if fetched_key: self.api_key = fetched_key if fetched_base: self.base_url = fetched_base self.registration_guardrail = EnkryptServerRegistrationGuardrail( api_key=self.api_key, base_url=self.base_url, config=self.config ) def get_name(self) -> str: """Get provider name.""" return "enkrypt" def get_version(self) -> str: """Get provider version.""" return "2.0.0" def _get_api_credentials(self) -> tuple[str, str]: """Get API key and base URL from plugin configuration.""" import json from secure_mcp_gateway.utils import ( CONFIG_PATH, DOCKER_CONFIG_PATH, does_file_exist, is_docker, ) # Load the full configuration file directly is_running_in_docker = is_docker() picked_config_path = DOCKER_CONFIG_PATH if is_running_in_docker else CONFIG_PATH if does_file_exist(picked_config_path): with open(picked_config_path, encoding="utf-8") as f: full_config = json.load(f) else: # Fallback to self values if config file not found return self.api_key, self.base_url plugins_config = full_config.get("plugins", {}) guardrails_config = plugins_config.get("guardrails", {}).get("config", {}) auth_config = plugins_config.get("auth", {}).get("config", {}) api_key = guardrails_config.get( "api_key", auth_config.get("api_key", self.api_key) ) base_url = guardrails_config.get( "base_url", auth_config.get("base_url", self.base_url) ) return api_key, base_url def create_input_guardrail( self, config: Dict[str, Any] ) -> Optional[InputGuardrail]: """Create Enkrypt input guardrail.""" if not config.get("enabled", False): return None api_key, base_url = self._get_api_credentials() return EnkryptInputGuardrail(config, api_key, base_url) def create_output_guardrail( self, config: Dict[str, Any] ) -> Optional[OutputGuardrail]: """Create Enkrypt output guardrail.""" if not config.get("enabled", False): return None api_key, base_url = self._get_api_credentials() return EnkryptOutputGuardrail(config, api_key, base_url) def create_pii_handler(self, config: Dict[str, Any]) -> Optional[PIIHandler]: """Create Enkrypt PII handler.""" if config.get("pii_redaction", False): api_key, base_url = self._get_api_credentials() return EnkryptPIIHandler(api_key, base_url) return None def validate_config(self, config: Dict[str, Any]) -> bool: """Validate Enkrypt configuration.""" if config.get("enabled", False): # Policy name is required when enabled if not config.get("policy_name"): return False return True def get_required_config_keys(self) -> List[str]: """Get required config keys.""" return ["enabled", "policy_name"] async def validate_server_registration( self, request: ServerRegistrationRequest ) -> Optional[GuardrailResponse]: """Validate server registration using Enkrypt batch API.""" return await self.registration_guardrail.validate_server(request) async def validate_tool_registration( self, request: ToolRegistrationRequest ) -> Optional[GuardrailResponse]: """Validate tool registration using Enkrypt batch API.""" return await self.registration_guardrail.validate_tools(request) def get_metadata(self) -> Dict[str, Any]: """Get provider metadata.""" base_metadata = super().get_metadata() base_metadata.update( { "api_url": self.base_url, "supports_async": True, "supports_batch": True, "max_content_length": 100000, "supports_policy_detection": True, "supports_relevancy": True, "supports_adherence": True, "supports_hallucination": True, "supports_pii_redaction": True, "supports_registration_validation": True, "supports_server_validation": True, "supports_tool_validation": True, } ) return base_metadata

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/enkryptai/secure-mcp-gateway'

If you have feedback or need assistance with the MCP directory API, please join our Discord server