"""Common utilities for MCP Gateway."""
import json
import os
import secrets
import socket
import string
import sys
import threading
import time
from typing import Any, Dict, Union
from urllib.parse import urlparse
from secure_mcp_gateway.consts import (
CONFIG_PATH,
DEFAULT_COMMON_CONFIG,
DOCKER_CONFIG_PATH,
EXAMPLE_CONFIG_NAME,
EXAMPLE_CONFIG_PATH,
)
# Lazy import to avoid circular imports.
#
# The whole application shares one centralized logger reached through a lazy
# accessor. Most modules should import and use `utils.logger`, which forwards
# to the active telemetry provider's logger when telemetry is enabled, so that
# formatting, context and export behaviour live in one place.
#
# IMPORTANT: telemetry modules themselves (e.g. providers under
# `plugins/telemetry`) MUST NOT import this logger from `utils`, because
# `utils` depends on telemetry during initialization. They should create a
# local module logger with `logging.getLogger(...)` instead, to avoid circular
# imports during bootstrap.
_logger_cache = None


def get_logger():
    """Return the active application logger, resolving it lazily.

    The telemetry config manager is imported on first use rather than at
    module import time, which avoids circular imports during startup. If the
    import or lookup fails (or yields no logger) this returns ``None``; the
    lookup is attempted again on the next call, and ``LazyLogger`` turns
    logging calls into no-ops in the meantime.
    """
    global _logger_cache
    if _logger_cache is not None:
        return _logger_cache
    try:
        from secure_mcp_gateway.plugins.telemetry import (
            get_telemetry_config_manager,
        )

        _logger_cache = get_telemetry_config_manager().get_logger()
    except Exception:
        # Telemetry not importable / not ready yet: report "no logger".
        _logger_cache = None
    return _logger_cache
# For backward compatibility, expose logger as a module-level variable.
class LazyLogger:
    """Lazy logger wrapper used by application modules.

    Accessing any logging method (e.g. ``.info``, ``.debug``) forwards to the
    real telemetry-backed logger when ``get_logger()`` returns one; otherwise
    the attribute resolves to a no-op callable. This lets every module import
    ``logger`` from ``utils`` without eagerly initializing telemetry.

    Dunder attribute lookups are never proxied. Protocol probes such as
    ``copy.deepcopy`` (which looks up ``__deepcopy__``) or
    ``hasattr(obj, "__wrapped__")`` get a normal ``AttributeError`` instead
    of a no-op lambda that would silently hijack the protocol.
    """

    def __getattr__(self, name):
        # __getattr__ only runs for names not found through normal lookup, so
        # dunders that really exist on ``object`` are unaffected by this guard.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        real_logger = get_logger()
        if real_logger:
            return getattr(real_logger, name)
        # No logger available yet: swallow the call.
        return lambda *args, **kwargs: None


# Central application logger for non-telemetry modules.
#
# Usage guidance:
# - In most modules, prefer: `from secure_mcp_gateway.utils import logger`
# - In telemetry provider/config modules, prefer a local
#   `logging.getLogger("enkrypt.telemetry")` to avoid importing `utils`
#   (which depends on telemetry initialization) and creating a circular
#   dependency.
logger = LazyLogger()
from secure_mcp_gateway.version import __version__
# Get debug log level (lazy-loaded to avoid circular imports).
def _get_debug_log_level():
    """Return True when the configured ``enkrypt_log_level`` is "debug".

    The value is read through ``get_common_config()`` on every call rather
    than cached here. A missing, empty or ``null`` level falls back to "INFO",
    and non-string values are compared through ``str()`` so they cannot raise
    ``AttributeError`` on ``.lower()``.
    """
    level = get_common_config().get("enkrypt_log_level") or "INFO"
    return str(level).lower() == "debug"


# Property-like object so the check is deferred until it is actually tested,
# which avoids circular imports at module load time.
class _DebugLevel:
    """Proxy whose truth value is the current debug-level check."""

    def __bool__(self):
        return _get_debug_log_level()


IS_DEBUG_LOG_LEVEL = _DebugLevel()
# NOTE:
# This module is imported very early in the gateway startup sequence by
# several subsystems. At that point the telemetry provider (which owns the
# logger configuration) may not be initialized yet, so acquiring a logger can
# return None and LazyLogger will silently no-op. To keep critical bootstrap
# diagnostics visible, key messages are mirrored to stderr with print() in
# addition to the logger call. Once telemetry is up, logger messages flow
# through the configured provider as usual.

# Memoized result of is_telemetry_enabled(): None means "not determined yet";
# True/False is the cached answer.
IS_TELEMETRY_ENABLED: Union[bool, None] = None
# --------------------------------------------------------------------------
# The functions below are also redefined in telemetry.py to avoid circular
# imports. If their logic changes, please update both files.
# --------------------------------------------------------------------------
def get_file_from_root(file_name):
    """
    Return the absolute path of ``file_name`` inside the project root.

    The project root is taken to be two directory levels above the folder
    that contains this module.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    # Climb two levels: module_dir -> parent -> grandparent.
    project_root = os.path.dirname(os.path.dirname(module_dir))
    return os.path.join(project_root, file_name)
def get_absolute_path(file_name):
    """
    Resolve ``file_name`` relative to the directory containing this module.
    """
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_name)
def does_file_exist(file_name_or_path, is_absolute_path=None):
    """
    Return True if the given path exists.

    When ``is_absolute_path`` is None it is inferred with ``os.path.isabs``.
    Absolute paths are checked as given. Relative names are resolved against
    this module's directory (via ``get_absolute_path``), not the process
    working directory. ``os.path.exists`` is also True for directories.
    """
    if is_absolute_path is None:
        is_absolute_path = os.path.isabs(file_name_or_path)
    target = (
        file_name_or_path
        if is_absolute_path
        else get_absolute_path(file_name_or_path)
    )
    return os.path.exists(target)
def is_docker():
    """
    Return True if the process appears to be running inside a container.

    Checks, in order:
    1. The marker files ``/.dockerenv`` and ``/run/.containerenv``.
    2. Lines of ``/proc/1/cgroup`` mentioning docker, kubepods, containerd
       or lxc.

    Any OSError while reading ``/proc/1/cgroup`` (the file is missing outside
    Linux, and may be unreadable in restricted sandboxes) is treated as "not
    a container" instead of propagating to callers such as config loading.
    """
    container_markers = ("/.dockerenv", "/run/.containerenv")
    if any(os.path.exists(marker) for marker in container_markers):
        return True

    container_identifiers = ("docker", "kubepods", "containerd", "lxc")
    try:
        with open("/proc/1/cgroup", encoding="utf-8", errors="replace") as f:
            return any(
                keyword in line
                for line in f
                for keyword in container_identifiers
            )
    except OSError:
        # Covers FileNotFoundError (non-Linux) and PermissionError alike.
        return False
# Config cache with file modification time tracking for hot-reload support.
# Reads and writes of these globals in this module are done while holding
# ``_config_lock`` (a re-entrant lock) so concurrent callers see a consistent
# cache.
_config_cache = {}
_config_mtime = 0
_config_path_cached = None
_config_lock = threading.RLock()

# Value stored in ``_config_path_cached`` when neither the real config nor the
# example config exists, so the pure-defaults result is cached like any other
# source instead of being rebuilt (and re-logged) on every call.
_DEFAULTS_ONLY_SOURCE = "<default common config>"


def _default_common_config():
    """Return a fresh copy of the built-in defaults with an empty plugins map."""
    return {**DEFAULT_COMMON_CONFIG, "plugins": {}}


def _merge_with_defaults(config):
    """Overlay a parsed config file on the defaults, tolerating bad shapes.

    A JSON root that is not an object, or ``common_mcp_gateway_config`` /
    ``plugins`` sections that are not objects (e.g. ``null``), are treated as
    empty instead of raising.
    """
    if not isinstance(config, dict):
        logger.error(
            f"[utils] Config root must be a JSON object, got "
            f"{type(config).__name__}. Using default common config."
        )
        config = {}
    common_config = config.get("common_mcp_gateway_config", {})
    if not isinstance(common_config, dict):
        common_config = {}
    plugins_config = config.get("plugins", {})
    if not isinstance(plugins_config, dict):
        plugins_config = {}
    return {**DEFAULT_COMMON_CONFIG, **common_config, "plugins": plugins_config}


def get_common_config(print_debug=False):
    """
    Get the common configuration for the Enkrypt Secure MCP Gateway.

    Source selection:
    - Inside a container (``is_docker()``) ``DOCKER_CONFIG_PATH`` is used,
      otherwise ``CONFIG_PATH``.
    - If that file does not exist, the example config at
      ``EXAMPLE_CONFIG_PATH`` is used; if that is missing too, only
      ``DEFAULT_COMMON_CONFIG`` is used.

    The result is ``DEFAULT_COMMON_CONFIG`` overlaid with the file's
    ``common_mcp_gateway_config`` section, plus a ``"plugins"`` key holding
    the file's ``plugins`` section.

    Caching / hot reload: the real config file is re-read only when its mtime
    changes, so edits (e.g. in a mounted Docker volume) are picked up on the
    next call. A fallback source (example config or pure defaults) is cached
    until the selected source changes, so it is not re-read and re-logged on
    every call. If the real file cannot be stat'ed or parsed, the last cached
    result is returned, or the defaults if nothing was cached yet.

    Args:
        print_debug (bool): Emit extra debug logs about paths and loaded config.

    Returns:
        dict: The merged configuration. This is the cached object itself, so
        callers should treat it as read-only.
    """
    global _config_cache, _config_mtime, _config_path_cached
    # NOTE: sys_print is not used here because it would create a circular
    # call chain between get_common_config, is_telemetry_enabled and
    # sys_print. ``logger`` is used directly, and the load message is also
    # printed to stderr so it is visible before telemetry is initialized.
    if print_debug:
        logger.debug(f"[utils] config_path: {CONFIG_PATH}")
        logger.debug(f"[utils] docker_config_path: {DOCKER_CONFIG_PATH}")
        logger.debug(f"[utils] example_config_path: {EXAMPLE_CONFIG_PATH}")

    picked_config_path = DOCKER_CONFIG_PATH if is_docker() else CONFIG_PATH

    with _config_lock:
        if does_file_exist(picked_config_path):
            try:
                current_mtime = os.path.getmtime(picked_config_path)
            except OSError as e:
                logger.warning(
                    f"[utils] Error getting mtime for {picked_config_path}: {e}"
                )
                # Keep serving the last cached config, otherwise the defaults.
                return _config_cache or _default_common_config()

            # Unchanged since the last successful load: serve from cache.
            if (
                _config_cache
                and _config_path_cached == picked_config_path
                and current_mtime == _config_mtime
            ):
                return _config_cache

            # File changed or first load - reload config.
            try:
                print(f"[utils] Loading {picked_config_path} file...", file=sys.stderr)
                logger.info(f"[utils] Loading {picked_config_path} file...")
                with open(picked_config_path, encoding="utf-8") as f:
                    config = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                logger.error(
                    f"[utils] Error loading config from {picked_config_path}: {e}"
                )
                return _config_cache or _default_common_config()
            source, source_mtime = picked_config_path, current_mtime
        else:
            has_example = does_file_exist(EXAMPLE_CONFIG_PATH)
            source = EXAMPLE_CONFIG_PATH if has_example else _DEFAULTS_ONLY_SOURCE
            # This fallback is already loaded: don't re-read / re-log it.
            if _config_cache and _config_path_cached == source:
                return _config_cache

            logger.info("[utils] No config file found. Loading example config.")
            if has_example:
                if print_debug:
                    logger.debug(f"[utils] Loading {EXAMPLE_CONFIG_NAME} file...")
                try:
                    with open(EXAMPLE_CONFIG_PATH, encoding="utf-8") as f:
                        config = json.load(f)
                except (OSError, ValueError) as e:
                    logger.error(f"[utils] Error loading example config: {e}")
                    config = {}
            else:
                logger.info(
                    "[utils] Example config file not found. Using default common config."
                )
                config = {}
            # Fallbacks are keyed by source only; mtime is not tracked for them.
            source_mtime = 0

        if print_debug and config:
            logger.debug(f"[utils] config: {config}")

        # Merge with defaults to ensure all required fields exist.
        _config_cache = _merge_with_defaults(config)
        _config_path_cached = source
        _config_mtime = source_mtime
        return _config_cache
def clear_config_cache():
    """Drop the cached config so the next get_common_config() call reloads it.

    Resets the cached dict, the recorded mtime and the cached source path
    while holding ``_config_lock``.
    """
    global _config_cache, _config_mtime, _config_path_cached
    with _config_lock:
        _config_cache, _config_mtime, _config_path_cached = {}, 0, None
def is_telemetry_enabled():
    """
    Return whether telemetry is enabled in config and its endpoint is reachable.

    The answer is memoized in the module-level ``IS_TELEMETRY_ENABLED`` after
    the first call (``None`` means "not determined yet"). Telemetry is
    reported as disabled when:
    - the telemetry plugin config is missing or has ``enabled`` false,
    - the configured ``url`` (default ``http://localhost:4317``) has no usable
      host/port, or
    - a TCP connection to it cannot be opened within the "connectivity"
      timeout, or the timeout manager cannot be imported.

    A URL without an explicit port uses its scheme's default port (80 for
    http, 443 for https).
    """
    global IS_TELEMETRY_ENABLED
    if IS_TELEMETRY_ENABLED is not None:
        return bool(IS_TELEMETRY_ENABLED)

    config = get_common_config()
    # ``or {}`` guards against sections explicitly set to null.
    telemetry_plugin_config = (config.get("plugins") or {}).get("telemetry") or {}
    telemetry_config = telemetry_plugin_config.get("config") or {}
    if not telemetry_config.get("enabled", False):
        IS_TELEMETRY_ENABLED = False
        return False

    endpoint = telemetry_config.get("url", "http://localhost:4317")
    try:
        parsed_url = urlparse(endpoint)
        hostname = parsed_url.hostname
        # .port raises ValueError for a non-numeric or out-of-range port.
        port = parsed_url.port
        if port is None:
            port = {"http": 80, "https": 443}.get(parsed_url.scheme)
        if not hostname or not port:
            logger.error(f"[utils] Invalid OTLP endpoint URL: {endpoint}")
            IS_TELEMETRY_ENABLED = False
            return False

        # Imported lazily to avoid import cycles at startup; an ImportError
        # here is handled like an unreachable endpoint.
        from secure_mcp_gateway.services.timeout import get_timeout_manager

        timeout_value = get_timeout_manager().get_timeout("connectivity")
        with socket.create_connection((hostname, port), timeout=timeout_value):
            IS_TELEMETRY_ENABLED = True
            return True
    except (ImportError, OSError, AttributeError, TypeError, ValueError) as e:
        logger.error(
            f"[utils] Telemetry is enabled in config, but endpoint {endpoint} is not accessible. So, disabling telemetry. Error: {e}"
        )
        IS_TELEMETRY_ENABLED = False
        return False
def generate_custom_id():
    """
    Build a unique identifier: 34 random alphanumeric characters, an
    underscore, then the current Unix time in milliseconds.

    Random characters are drawn with ``secrets.choice``. If generation fails
    for any reason, the error is logged and a coarser
    ``fallback_<epoch_seconds>`` id is returned instead.

    Returns:
        str: ``'{random_chars}_{timestamp_ms}'``
    """
    alphabet = string.ascii_letters + string.digits
    try:
        token = "".join([secrets.choice(alphabet) for _ in range(34)])
        millis = int(time.time() * 1000)
        return f"{token}_{millis}"
    except Exception as e:
        logger.error(f"[utils] Error generating custom ID: {e}")
        # Simpler, second-resolution id when the normal path fails.
        return f"fallback_{int(time.time())}"
def sys_print(*args, **kwargs):
    """
    Route a print-style message to the application logger.

    Positional arguments are converted with ``str()`` and joined with single
    spaces. Nothing is logged when no positional arguments are given.

    Keyword args:
        is_error (bool): log at error level.
        is_debug (bool): log at debug level (ignored when ``is_error`` is set).
        Any other keyword arguments are accepted and ignored.

    Logging failures never propagate; they are reported on stderr instead.
    """
    is_error = kwargs.pop("is_error", False)
    is_debug = kwargs.pop("is_debug", False)
    if not args:
        return
    try:
        text = " ".join(map(str, args))
        if is_error:
            emit = logger.error
        elif is_debug:
            emit = logger.debug
        else:
            emit = logger.info
        emit(text)
    except Exception as e:
        # Fallback to print if logger fails.
        print(f"[utils] Error logging using sys_print: {e}", file=sys.stderr)
def mask_key(key):
    """
    Mask a secret key for display, revealing at most its last 4 characters.

    Empty/None keys and keys of 4 characters or fewer are fully masked as
    ``"****"``; showing "the last 4" of a 4-character key would expose the
    entire key. Longer keys become ``"****"`` followed by their last 4
    characters.
    """
    if not key or len(key) <= 4:
        return "****"
    return "****" + key[-4:]
def _collect_auth_fields(ctx, fields):
    """Fill ``fields`` in place with auth/gateway details resolved from ``ctx``.

    May raise. The caller swallows errors, so values resolved before a
    failure are kept and the rest stay at their defaults.
    """
    # Only attempt auth lookups when ctx looks like an MCP Context.
    if not (hasattr(ctx, "request_context") or hasattr(ctx, "__dict__")):
        return

    from secure_mcp_gateway.plugins.auth import get_auth_config_manager

    auth_manager = get_auth_config_manager()
    credentials = auth_manager.get_gateway_credentials(ctx)
    gateway_key = credentials.get("gateway_key")
    fields["project_id"] = credentials.get("project_id", fields["project_id"])
    fields["user_id"] = credentials.get("user_id", fields["user_id"])
    if not gateway_key:
        return

    import asyncio

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # No running loop: safe to drive the coroutine with asyncio.run().
    else:
        # Already inside an event loop: skip the lookup rather than create a
        # coroutine that would never be awaited.
        return

    gateway_config = (
        asyncio.run(
            auth_manager.get_local_mcp_config(
                gateway_key, fields["project_id"], fields["user_id"]
            )
        )
        or {}
    )
    fields["project_name"] = gateway_config.get(
        "project_name", fields["project_name"]
    )
    fields["email"] = gateway_config.get("email", fields["email"])
    fields["mcp_config_id"] = gateway_config.get(
        "mcp_config_id", fields["mcp_config_id"]
    )


def build_log_extra(ctx, custom_id=None, server_name=None, error=None, **kwargs):
    """Build the structured ``extra`` dict attached to log records.

    Identity fields (project/user/email/config id) are resolved from ``ctx``
    through the auth manager when ``ctx`` looks like an MCP Context. Any
    failure, or a ctx that does not qualify, leaves them as "not_provided".
    Falsy values are rendered as "". Extra keyword arguments are appended
    (and may override the standard keys) unless their value is None.
    """
    fields = dict.fromkeys(
        ("project_id", "user_id", "project_name", "email", "mcp_config_id"),
        "not_provided",
    )
    try:
        _collect_auth_fields(ctx, fields)
    except Exception:
        # Swallow errors and use defaults to avoid breaking logging.
        pass

    extra = {
        "custom_id": custom_id or "",
        "server_name": server_name or "",
        "project_id": fields["project_id"] or "",
        "project_name": fields["project_name"] or "",
        "user_id": fields["user_id"] or "",
        "email": fields["email"] or "",
        "mcp_config_id": fields["mcp_config_id"] or "",
        "error": error or "",
    }
    extra.update({name: value for name, value in kwargs.items() if value is not None})
    return extra
def mask_server_config_sensitive_data(server_info):
    """
    Return a copy of a server configuration that is safe to send to clients.

    Only ``server_info["config"]["env"]`` is masked, via
    ``mask_sensitive_env_vars``. The input is deep-copied first, so the
    caller's configuration is never modified.

    Args:
        server_info (dict): Server configuration dictionary.

    Returns:
        dict: Masked copy of the configuration. Falsy input is returned as-is.
    """
    if not server_info:
        return server_info

    import copy

    result = copy.deepcopy(server_info)
    if "config" not in result:
        return result
    server_config = result["config"]
    if "env" in server_config:
        server_config["env"] = mask_sensitive_env_vars(server_config["env"])
    return result
def mask_sensitive_env_vars(env_vars):
    """
    Mask environment variables whose names suggest tokens, keys, or secrets.

    A variable is sensitive when its lower-cased name contains any of the
    substrings listed below (e.g. ``token``, ``key``, ``password``, ``url``).
    For sensitive variables with a truthy value:
    - strings longer than 8 characters keep their first and last 4
      characters (``abcd****wxyz``);
    - shorter strings and any non-string value (ints, lists, ...) become
      ``"****"``.
    Falsy values (``""``, ``0``, ``None``) and non-sensitive variables are
    copied through unchanged.

    Args:
        env_vars (dict): Mapping of variable name to value.

    Returns:
        dict: A new mapping with sensitive values masked. Falsy input is
        returned as-is.
    """
    if not env_vars:
        return env_vars

    sensitive_keys = (
        "token",
        "key",
        "secret",
        "password",
        "pass",
        "auth",
        "credential",
        "api_key",
        "access_token",
        "refresh_token",
        "bearer",
        "jwt",
        "github_token",
        "github_key",
        "gitlab_token",
        "bitbucket_token",
        "aws_key",
        "aws_secret",
        "azure_key",
        "gcp_key",
        "database_url",
        "connection_string",
        "uri",
        "url",
    )

    masked_env = {}
    for key, value in env_vars.items():
        # str() so a non-string variable name cannot raise on .lower().
        key_lower = str(key).lower()
        is_sensitive = any(part in key_lower for part in sensitive_keys)
        if not (is_sensitive and value):
            masked_env[key] = value
        elif isinstance(value, str) and len(value) > 8:
            # Show only the first and last 4 characters.
            masked_env[key] = value[:4] + "****" + value[-4:]
        else:
            # Short strings and non-string values are masked entirely rather
            # than sliced.
            masked_env[key] = "****"
    return masked_env
def get_server_info_by_name(gateway_config, server_name):
    """
    Retrieve a server's configuration by name from the gateway config.

    Searches ``gateway_config["mcp_config"]`` (a missing or ``null`` value is
    treated as an empty list) and returns the first entry whose
    ``server_name`` matches. When debug logging is enabled, the whole list is
    logged with each server's ``config.env`` values masked.

    Args:
        gateway_config (dict): Gateway/user configuration containing server details.
        server_name (str): Name of the server to look up.

    Returns:
        dict | None: The matching server configuration, or None if not found.
    """
    # IS_DEBUG_LOG_LEVEL calls get_common_config() each time it is tested,
    # so evaluate it once per lookup.
    debug_enabled = bool(IS_DEBUG_LOG_LEVEL)
    if debug_enabled:
        logger.debug(f"[get_server_info_by_name] Getting server info for {server_name}")

    mcp_config = gateway_config.get("mcp_config") or []

    if debug_enabled:
        # Mask sensitive data in debug logs; shallow copies leave the caller's
        # config untouched.
        masked_mcp_config = []
        for server in mcp_config:
            masked_server = server.copy()
            server_config = masked_server.get("config")
            if isinstance(server_config, dict) and "env" in server_config:
                masked_server["config"] = {
                    **server_config,
                    "env": mask_sensitive_env_vars(server_config["env"]),
                }
            masked_mcp_config.append(masked_server)
        logger.debug(f"[get_server_info_by_name] mcp_config: {masked_mcp_config}")

    return next((s for s in mcp_config if s.get("server_name") == server_name), None)
def mask_sensitive_headers(
    headers: Union[Dict[str, str], Dict[str, Any]],
) -> Dict[str, str]:
    """
    Return a copy of ``headers`` that is safe to log.

    A header is sensitive when its lower-cased name contains any of the
    patterns below (substring match). Sensitive values are replaced:
    - strings longer than 4 characters -> first 2 + ``***`` + last 2;
    - strings of 4 characters or fewer, empty strings, and non-string
      values -> ``***``.
    Non-sensitive headers are copied through unchanged. Falsy input yields
    an empty dict.

    Args:
        headers: Mapping of header name to value.

    Returns:
        A new dict with sensitive header values masked.
    """
    if not headers:
        return {}

    sensitive_patterns = (
        # Authentication headers
        "authorization",
        "auth",
        "bearer",
        "token",
        "apikey",
        "api-key",
        "api_key",
        "x-api-key",
        "x-auth-token",
        "x-access-token",
        "x-auth",
        "x-token",
        "x-enkrypt-api-key",
        "x-enkrypt-gateway-key",
        # Security headers
        "cookie",
        "set-cookie",
        "x-csrf-token",
        "x-csrf",
        "csrf-token",
        "x-requested-with",
        "x-forwarded-for",
        "x-real-ip",
        # Sensitive data headers
        "password",
        "passwd",
        "pwd",
        "secret",
        "private",
        "key",
        "session",
        "sessionid",
        "session-id",
        "sess",
        # Custom sensitive headers
        "x-session",
        "x-user",
        "x-tenant",
        "x-org",
        "x-organization",
        "x-client",
        "x-device",
        "x-device-id",
        "x-deviceid",
        # OAuth and JWT
        "oauth",
        "jwt",
        "access-token",
        "refresh-token",
        "id-token",
        "x-oauth",
        "x-jwt",
        "x-access",
        "x-refresh",
    )

    def _is_sensitive(name):
        lowered = name.lower()
        return any(pattern in lowered for pattern in sensitive_patterns)

    def _redact(value):
        # Keep a hint of longer string values; hide everything else.
        if isinstance(value, str) and len(value) > 4:
            return f"{value[:2]}***{value[-2:]}"
        return "***"

    return {
        name: _redact(value) if _is_sensitive(name) else value
        for name, value in headers.items()
    }
def mask_sensitive_data(
    data: Dict[str, Any], sensitive_keys: Union[list, None] = None
) -> Dict[str, Any]:
    """
    Recursively mask sensitive information in a dictionary.

    A key is sensitive when its lower-cased string form contains any entry of
    ``sensitive_keys`` (substring match). A sensitive value that is a string
    longer than 4 characters becomes first 2 + ``***`` + last 2. Every other
    sensitive value (short/empty strings, numbers, containers, None) becomes
    ``***``. Non-sensitive values are walked recursively: nested dicts are
    masked, and lists are masked element by element at any depth (dicts
    inside lists, lists inside lists).

    Args:
        data: Dictionary to mask. Non-dict input is returned unchanged.
        sensitive_keys: Substrings that mark a key as sensitive. Defaults to
            a built-in list of common credential/session key names.

    Returns:
        A new dict with sensitive values masked; the input is not modified.
    """
    if sensitive_keys is None:
        sensitive_keys = [
            "password",
            "passwd",
            "pwd",
            "secret",
            "private",
            "key",
            "token",
            "apikey",
            "api_key",
            "api-key",
            "auth",
            "authorization",
            "bearer",
            "session",
            "sessionid",
            "session-id",
            "cookie",
            "csrf",
            "oauth",
            "jwt",
            "access-token",
            "refresh-token",
            "id-token",
            "x-api-key",
            "x-auth-token",
            "x-access-token",
            "x-csrf-token",
            "x-enkrypt-api-key",
            "x-enkrypt-gateway-key",
        ]

    if not isinstance(data, dict):
        return data

    def _mask_nested(value):
        # Recurse into containers so secrets nested at any depth are masked.
        if isinstance(value, dict):
            return mask_sensitive_data(value, sensitive_keys)
        if isinstance(value, list):
            return [_mask_nested(item) for item in value]
        return value

    masked_data = {}
    for key, value in data.items():
        # str() so non-string keys (ints, tuples) don't raise on .lower().
        key_lower = str(key).lower()
        if any(pattern in key_lower for pattern in sensitive_keys):
            if isinstance(value, str) and len(value) > 4:
                masked_data[key] = f"{value[:2]}***{value[-2:]}"
            else:
                masked_data[key] = "***"
        else:
            masked_data[key] = _mask_nested(value)
    return masked_data