"""Example guardrail providers."""
import asyncio
from typing import Any, Dict, List, Optional
import httpx
from secure_mcp_gateway.plugins.guardrails.base import (
GuardrailAction,
GuardrailProvider,
GuardrailRequest,
GuardrailResponse,
GuardrailViolation,
InputGuardrail,
OutputGuardrail,
ViolationType,
)
# ============================================================================
# OpenAI Moderation API Provider
# ============================================================================
class OpenAIInputGuardrail:
"""OpenAI Moderation API input guardrail implementation."""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.api_key = config.get("api_key", "")
self.threshold = config.get("threshold", 0.7)
self.block_categories = config.get(
"block_categories", ["hate", "violence", "sexual"]
)
async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
"""Validate using OpenAI Moderation API."""
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/moderations",
headers={"Authorization": f"Bearer {self.api_key}"},
json={"input": request.content},
)
result = response.json()
moderation_result = result["results"][0]
violations = []
is_safe = True
# Check categories
categories = moderation_result.get("categories", {})
category_scores = moderation_result.get("category_scores", {})
for category, flagged in categories.items():
if flagged and category in self.block_categories:
score = category_scores.get(category, 0.0)
if score >= self.threshold:
is_safe = False
violations.append(
GuardrailViolation(
violation_type=self._map_category_to_violation(
category
),
severity=score,
message=f"Content flagged for {category}",
action=GuardrailAction.BLOCK,
metadata={"category": category, "score": score},
)
)
return GuardrailResponse(
is_safe=is_safe,
action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
violations=violations,
metadata={"provider": "openai-moderation"},
)
def get_supported_detectors(self) -> List[ViolationType]:
"""Get supported detectors."""
return [
ViolationType.TOXIC_CONTENT,
ViolationType.NSFW_CONTENT,
ViolationType.CUSTOM,
]
def _map_category_to_violation(self, category: str) -> ViolationType:
"""Map OpenAI categories to violation types."""
mapping = {
"hate": ViolationType.TOXIC_CONTENT,
"violence": ViolationType.TOXIC_CONTENT,
"sexual": ViolationType.NSFW_CONTENT,
"self-harm": ViolationType.TOXIC_CONTENT,
}
return mapping.get(category, ViolationType.CUSTOM)
class OpenAIGuardrailProvider(GuardrailProvider):
"""OpenAI Moderation API provider."""
def __init__(self, api_key: str):
self.api_key = api_key
def get_name(self) -> str:
return "openai-moderation"
def get_version(self) -> str:
return "1.0.0"
def create_input_guardrail(
self, config: Dict[str, Any]
) -> Optional[InputGuardrail]:
"""Create OpenAI input guardrail."""
if not config.get("enabled", False):
return None
config["api_key"] = self.api_key
return OpenAIInputGuardrail(config)
def create_output_guardrail(
self, config: Dict[str, Any]
) -> Optional[OutputGuardrail]:
"""OpenAI Moderation can also work on outputs."""
if not config.get("enabled", False):
return None
        # Reuse the input implementation for outputs; copy the config so the
        # caller's dict is not mutated.
        config = {**config, "api_key": self.api_key}
class OpenAIOutputGuardrail:
def __init__(self, input_guardrail):
self._input_guardrail = input_guardrail
async def validate(
self, response_content: str, original_request: GuardrailRequest
) -> GuardrailResponse:
# Create a new request with the response content
return await self._input_guardrail.validate(
GuardrailRequest(content=response_content)
)
def get_supported_detectors(self) -> List[ViolationType]:
return self._input_guardrail.get_supported_detectors()
return OpenAIOutputGuardrail(OpenAIInputGuardrail(config))
def validate_config(self, config: Dict[str, Any]) -> bool:
"""Validate configuration."""
if config.get("enabled", False):
if not self.api_key:
return False
return True
def get_required_config_keys(self) -> List[str]:
"""Get required config keys."""
return ["enabled"]
# ============================================================================
# AWS Comprehend Provider (Example)
# ============================================================================
class AWSComprehendInputGuardrail:
"""AWS Comprehend sentiment/PII detection guardrail."""
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.region = config.get("region", "us-east-1")
        self.detect_pii = config.get("detect_pii", True)
        self.detect_sentiment = config.get("detect_sentiment", True)
        self.negative_threshold = config.get("negative_threshold", 0.8)
        self.language_code = config.get("language_code", "en")
        # boto3 is an optional dependency, so import it lazily: the module
        # still imports cleanly when boto3 is not installed.
        import boto3
        self.client = boto3.client("comprehend", region_name=self.region)
    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Validate using AWS Comprehend."""
        violations = []
        # boto3 calls are synchronous, so run them in a worker thread to
        # avoid blocking the event loop.
        if self.detect_pii:
            pii_result = await asyncio.to_thread(
                self.client.detect_pii_entities,
                Text=request.content,
                LanguageCode=self.language_code,
            )
            for entity in pii_result.get("Entities", []):
                violations.append(
                    GuardrailViolation(
                        violation_type=ViolationType.PII,
                        severity=entity.get("Score", 0.0),
                        message=f"PII detected: {entity.get('Type', 'UNKNOWN')}",
                        action=GuardrailAction.WARN,
                        metadata={"entity_type": entity.get("Type")},
                    )
                )
        if self.detect_sentiment:
            sentiment_result = await asyncio.to_thread(
                self.client.detect_sentiment,
                Text=request.content,
                LanguageCode=self.language_code,
            )
            if sentiment_result.get("Sentiment") == "NEGATIVE":
                score = sentiment_result["SentimentScore"]["Negative"]
                if score >= self.negative_threshold:
                    violations.append(
                        GuardrailViolation(
                            violation_type=ViolationType.TOXIC_CONTENT,
                            severity=score,
                            message="Strongly negative sentiment detected",
                            action=GuardrailAction.WARN,
                            metadata={"negative_score": score},
                        )
                    )
        is_safe = len(violations) == 0
        return GuardrailResponse(
            is_safe=is_safe,
            action=GuardrailAction.ALLOW if is_safe else GuardrailAction.WARN,
            violations=violations,
            metadata={"provider": "aws-comprehend", "region": self.region},
        )
def get_supported_detectors(self) -> List[ViolationType]:
"""Get supported detectors."""
return [ViolationType.PII, ViolationType.TOXIC_CONTENT]
class AWSComprehendProvider(GuardrailProvider):
"""AWS Comprehend provider."""
def __init__(self, region: str = "us-east-1"):
self.region = region
def get_name(self) -> str:
return "aws-comprehend"
def get_version(self) -> str:
return "1.0.0"
def create_input_guardrail(
self, config: Dict[str, Any]
) -> Optional[InputGuardrail]:
"""Create AWS Comprehend input guardrail."""
if not config.get("enabled", False):
return None
config["region"] = self.region
return AWSComprehendInputGuardrail(config)
def create_output_guardrail(
self, config: Dict[str, Any]
) -> Optional[OutputGuardrail]:
"""AWS Comprehend can work on outputs too."""
# Similar to input, can be reused
return None
def get_required_config_keys(self) -> List[str]:
return ["enabled"]
# ============================================================================
# Custom Regex/Keyword Based Provider (Simple Example)
# ============================================================================
class CustomKeywordGuardrail:
"""Simple keyword-based guardrail."""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.blocked_keywords = config.get("blocked_keywords", [])
self.case_sensitive = config.get("case_sensitive", False)
async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
"""Validate using keyword matching."""
violations = []
content = request.content
if not self.case_sensitive:
content = content.lower()
blocked_keywords = [kw.lower() for kw in self.blocked_keywords]
else:
blocked_keywords = self.blocked_keywords
        # Simple substring matching: "pass" also matches "passport". Use a
        # regex with word boundaries if whole-word matching is needed.
        for keyword in blocked_keywords:
if keyword in content:
violations.append(
GuardrailViolation(
violation_type=ViolationType.KEYWORD_VIOLATION,
severity=1.0,
message=f"Blocked keyword detected: {keyword}",
action=GuardrailAction.BLOCK,
metadata={"keyword": keyword},
)
)
is_safe = len(violations) == 0
return GuardrailResponse(
is_safe=is_safe,
action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
violations=violations,
metadata={"provider": "custom-keyword"},
)
def get_supported_detectors(self) -> List[ViolationType]:
"""Get supported detectors."""
return [ViolationType.KEYWORD_VIOLATION]
class CustomKeywordProvider(GuardrailProvider):
"""Custom keyword-based provider."""
def __init__(self, blocked_keywords: List[str]):
self.blocked_keywords = blocked_keywords
def get_name(self) -> str:
return "custom-keyword"
def get_version(self) -> str:
return "1.0.0"
def create_input_guardrail(
self, config: Dict[str, Any]
) -> Optional[InputGuardrail]:
"""Create keyword input guardrail."""
if not config.get("enabled", False):
return None
config["blocked_keywords"] = self.blocked_keywords
return CustomKeywordGuardrail(config)
def create_output_guardrail(
self, config: Dict[str, Any]
) -> Optional[OutputGuardrail]:
"""Create keyword output guardrail."""
if not config.get("enabled", False):
return None
config["blocked_keywords"] = self.blocked_keywords
class CustomKeywordOutputGuardrail:
def __init__(self, input_guardrail):
self._input_guardrail = input_guardrail
async def validate(
self, response_content: str, original_request: GuardrailRequest
) -> GuardrailResponse:
return await self._input_guardrail.validate(
GuardrailRequest(content=response_content)
)
def get_supported_detectors(self) -> List[ViolationType]:
return self._input_guardrail.get_supported_detectors()
return CustomKeywordOutputGuardrail(CustomKeywordGuardrail(config))
def get_required_config_keys(self) -> List[str]:
return ["enabled", "blocked_keywords"]
# ============================================================================
# Composite Provider (Combines Multiple Providers)
# ============================================================================
class CompositeGuardrail:
"""Combines multiple guardrails with AND/OR logic."""
def __init__(self, guardrails: List[InputGuardrail], logic: str = "OR"):
"""
Args:
guardrails: List of guardrails to combine
logic: "OR" (any violation blocks) or "AND" (all must violate to block)
"""
        self.guardrails = guardrails
        self.logic = logic.upper()
        if self.logic not in ("AND", "OR"):
            raise ValueError(f"logic must be 'AND' or 'OR', got {logic!r}")
async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
"""Validate using all guardrails."""
all_violations = []
all_safe = []
for guardrail in self.guardrails:
result = await guardrail.validate(request)
all_violations.extend(result.violations)
all_safe.append(result.is_safe)
        # "OR": any unsafe result blocks, so every guardrail must report safe.
        # "AND": blocking requires all guardrails to flag the content.
        is_safe = all(all_safe) if self.logic == "OR" else any(all_safe)
return GuardrailResponse(
is_safe=is_safe,
action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
violations=all_violations,
metadata={"provider": "composite", "logic": self.logic},
)
def get_supported_detectors(self) -> List[ViolationType]:
"""Get all supported detectors from all guardrails."""
all_detectors = set()
for guardrail in self.guardrails:
all_detectors.update(guardrail.get_supported_detectors())
return list(all_detectors)
class CompositeGuardrailProvider(GuardrailProvider):
"""
Composite provider that combines multiple providers.
This is useful when you want to use multiple guardrail services together.
For example: Use Enkrypt for policy violations AND OpenAI for moderation.
"""
def __init__(self, providers: List[GuardrailProvider], logic: str = "OR"):
self.providers = providers
        # Uppercase here too: CompositeOutputGuardrail below compares against
        # the uppercase form.
        self.logic = logic.upper()
def get_name(self) -> str:
provider_names = "_".join([p.get_name() for p in self.providers])
return f"composite_{provider_names}"
def get_version(self) -> str:
return "1.0.0"
def create_input_guardrail(
self, config: Dict[str, Any]
) -> Optional[InputGuardrail]:
"""Create composite input guardrail."""
if not config.get("enabled", False):
return None
guardrails = []
for provider in self.providers:
guardrail = provider.create_input_guardrail(config)
if guardrail:
guardrails.append(guardrail)
if not guardrails:
return None
return CompositeGuardrail(guardrails, self.logic)
def create_output_guardrail(
self, config: Dict[str, Any]
) -> Optional[OutputGuardrail]:
"""Create composite output guardrail."""
if not config.get("enabled", False):
return None
guardrails = []
for provider in self.providers:
guardrail = provider.create_output_guardrail(config)
if guardrail:
guardrails.append(guardrail)
if not guardrails:
return None
class CompositeOutputGuardrail:
def __init__(self, guardrails_list, logic):
self.guardrails = guardrails_list
self.logic = logic
async def validate(
self, response_content: str, original_request: GuardrailRequest
) -> GuardrailResponse:
all_violations = []
all_safe = []
for guardrail in self.guardrails:
result = await guardrail.validate(
response_content, original_request
)
all_violations.extend(result.violations)
all_safe.append(result.is_safe)
                # Same semantics as CompositeGuardrail: "OR" blocks on any
                # violation, "AND" only when all guardrails flag the content.
                is_safe = all(all_safe) if self.logic == "OR" else any(all_safe)
return GuardrailResponse(
is_safe=is_safe,
action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
violations=all_violations,
metadata={"provider": "composite", "logic": self.logic},
)
def get_supported_detectors(self) -> List[ViolationType]:
all_detectors = set()
for guardrail in self.guardrails:
all_detectors.update(guardrail.get_supported_detectors())
return list(all_detectors)
return CompositeOutputGuardrail(guardrails, self.logic)
def get_required_config_keys(self) -> List[str]:
return ["enabled"]
def get_metadata(self) -> Dict[str, Any]:
"""Get metadata including all provider metadata."""
base = super().get_metadata()
base["providers"] = [p.get_metadata() for p in self.providers]
base["logic"] = self.logic
return base
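# ----------------------------------------------------------------------------
# Usage sketch (illustrative only): combine the keyword and OpenAI providers
# with "OR" logic so a hit from either blocks the request. The API key below
# is a placeholder, not a real credential.
# ----------------------------------------------------------------------------
async def _demo_composite() -> None:
    composite = CompositeGuardrailProvider(
        providers=[
            CustomKeywordProvider(blocked_keywords=["forbidden"]),
            OpenAIGuardrailProvider(api_key="sk-placeholder"),
        ],
        logic="OR",
    )
    guardrail = composite.create_input_guardrail({"enabled": True})
    if guardrail is not None:
        result = await guardrail.validate(GuardrailRequest(content="test input"))
        print(result.is_safe, result.metadata)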