"""Logging utilities with automatic correlation ID injection.
Provides:
1. Custom Logger class that auto-injects correlation IDs
2. Global logger factory functions
3. @log_call decorator for function entry/exit/error tracking
"""
import asyncio
import functools
import inspect
import logging
import time
from collections.abc import Callable

from src.utils.context import get_correlation_id
class CorrelationLogger(logging.Logger):
    """Custom Logger that automatically injects correlation IDs into log messages.

    Messages are transparently prefixed with the current correlation ID (as
    returned by ``get_correlation_id()``) without requiring changes to existing
    logging calls. Once installed via ``setup_correlation_logging()`` (which
    calls ``logging.setLoggerClass()``), loggers created by ``logging.getLogger()``
    *after* installation are instances of this class. Loggers that already
    existed before installation keep their original class and are not prefixed.

    When ``get_correlation_id()`` returns a falsy value, the message is logged
    unchanged.

    Example:
        >>> # Setup (typically in server.py)
        >>> setup_correlation_logging()
        >>>
        >>> # Usage (anywhere in codebase)
        >>> logger = logging.getLogger(__name__)
        >>> logger.info("Processing workflow")  # Automatically includes correlation ID
        >>> # Output: [abc-123] Processing workflow
    """

    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
        """Prefix *msg* with the current correlation ID, then delegate to ``Logger._log``.

        This override inserts an extra stack frame between the public logging
        call (``info()``, ``error()``, ...) and ``Logger._log``. ``findCaller``
        only skips frames that belong to the ``logging`` package itself, so the
        frame must be accounted for with ``stacklevel + 1``. Otherwise records'
        filename/lineno/funcName would point at this method instead of the real
        caller (Python 3.11+), or be off by one for caller-supplied stacklevels.
        """
        correlation_id = get_correlation_id()
        if correlation_id:
            cid = str(correlation_id)
            if args:
                # LogRecord.getMessage() evaluates ``msg % args`` when args are
                # present, so a literal "%" in the ID must not be read as a
                # format directive.
                cid = cid.replace("%", "%%")
            msg = f"[{cid}] {msg}"
        super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel + 1)
def setup_correlation_logging() -> None:
    """Make CorrelationLogger the class used for newly created loggers.

    Call this once, as early as possible during application startup (e.g. in
    server.py). It delegates to ``logging.setLoggerClass``, which affects only
    loggers instantiated *after* the call. Loggers that already exist keep
    their original class and do not get the correlation-ID prefix. This
    includes loggers created at import time, such as this module's
    module-level ``logger``, and the root logger. Calling this function again
    simply re-installs the same class.

    Example:
        >>> # In server.py
        >>> from src.utils.logging import setup_correlation_logging
        >>> setup_correlation_logging()
        >>>
        >>> # Loggers created from now on include correlation IDs
        >>> logger = logging.getLogger("my_module")
        >>> logger.info("Hello")  # → "[abc-123] Hello"
    """
    logger_cls = CorrelationLogger
    logging.setLoggerClass(logger_cls)
def get_global_logger(name: str) -> logging.Logger:
    """Get a logger, enabling correlation logging globally first if needed.

    This is a convenience wrapper around ``logging.getLogger()``. If the
    installed logger class is not CorrelationLogger (or a subclass of it),
    ``setup_correlation_logging()`` is called before the logger is fetched.

    Args:
        name: Logger name (typically ``__name__``).

    Returns:
        The logger returned by ``logging.getLogger(name)``. A newly created
        logger is a CorrelationLogger and injects correlation IDs.

    Note:
        ``logging.getLogger`` returns an existing logger unchanged. If a logger
        called *name* was created before correlation logging was enabled, it
        keeps its original class and will NOT inject correlation IDs.

    Example:
        >>> from src.utils.logging import get_global_logger
        >>> logger = get_global_logger(__name__)
        >>> logger.info("Processing")  # → "[abc-123] Processing"
    """
    # issubclass rather than an equality check: if a subclass of
    # CorrelationLogger is already installed, keep it instead of downgrading
    # the global logger class to the base CorrelationLogger.
    if not issubclass(logging.getLoggerClass(), CorrelationLogger):
        setup_correlation_logging()
    return logging.getLogger(name)
# Package-wide default logger, kept for backward compatibility; log_call() writes
# its records through it. This assignment runs while the module is being imported,
# i.e. before setup_correlation_logging() can have been called, so it is presumably
# a plain logging.Logger without automatic prefixing — which would explain why
# log_call() adds the correlation-ID prefix itself.
# NOTE(review): confirm nothing makes this a CorrelationLogger, or log_call()
# output would carry the correlation ID twice.
logger = logging.getLogger(name="ComfyUI_MCP")
def log_call(
    action_name: str | None = None,
    level_name: str = "route",
    log_level: str | int = "INFO",
    log_params: bool = False,
    sensitive_params: list[str] | None = None,
) -> Callable:
    """Decorator for automatic function logging with entry/exit/error tracking.

    This decorator provides standardized logging for route functions, following
    the aspect-oriented programming pattern used in dl_remuxed. It automatically
    logs function entry, exit, execution time, and errors without cluttering
    business logic.

    Supports both synchronous and asynchronous (``async def``) functions. All
    records go to the module-level ``"ComfyUI_MCP"`` logger. Each line is
    prefixed with ``[<correlation id>] `` when ``get_correlation_id()`` returns a
    truthy value.

    Args:
        action_name: Optional custom name for the action (defaults to the
            function's ``__name__``).
        level_name: Category tag for the log entry (e.g. "route", "client", "auth").
        log_level: Level for the ENTER/SUCCESS records. Either a level name such
            as "DEBUG", "INFO", "WARNING", "ERROR" (case-insensitive) or a
            numeric ``logging`` level. Unrecognized names fall back to INFO.
            Failures are always logged at ERROR.
        log_params: If True, log the call's *keyword* arguments on entry
            (default: False). Positional arguments are not logged.
        sensitive_params: Parameter names whose values are replaced with
            "[REDACTED]" in the logged params (e.g. ["api_key", "token"]).

    Returns:
        A decorator that wraps a function with entry/exit/error logging. The
        wrapped function's return value is passed through unchanged. Exceptions
        that subclass ``Exception`` are logged and then re-raised.
        ``BaseException``s such as ``asyncio.CancelledError`` or
        ``KeyboardInterrupt`` propagate without an ERROR line.

    Example:
        >>> @log_call(action_name="queue_workflow", level_name="route")
        ... async def queue_workflow(auth, workflow):
        ...     return await some_api_call()

        Logs:
            INFO: [route] queue_workflow - ENTER
            INFO: [route] queue_workflow - SUCCESS (0.245s)

    Example with parameter logging:
        >>> @log_call(
        ...     action_name="configure_webhook",
        ...     level_name="tool",
        ...     log_params=True,
        ...     sensitive_params=["webhook_secret", "api_key"]
        ... )
        ... async def configure_webhook(url: str, webhook_secret: str):
        ...     pass

        Logs (when called with keyword arguments):
            INFO: [tool] configure_webhook - ENTER params={'url': 'https://...', 'webhook_secret': '[REDACTED]'}
    """
    # Resolve a numeric level by name instead of getattr(logger, name): that
    # lookup can land on non-level Logger attributes ("log", "disabled",
    # "name", ...) which only blow up when the decorated function runs.
    if isinstance(log_level, int):
        level = log_level
    else:
        resolved = logging.getLevelName(str(log_level).upper())
        level = resolved if isinstance(resolved, int) else logging.INFO

    def decorator(func: Callable) -> Callable:
        # Use custom name or function name
        name = action_name or func.__name__
        # Set for O(1) membership tests during redaction
        sens_params = set(sensitive_params or [])

        def sanitize_params(params: dict) -> dict:
            """Return *params* with the values of sensitive keys replaced by "[REDACTED]"."""
            if not sens_params:
                return params
            return {k: "[REDACTED]" if k in sens_params else v for k, v in params.items()}

        def _enter(kwargs: dict) -> tuple[str, float]:
            """Log the ENTER record; return the correlation prefix and the start time."""
            # perf_counter is monotonic, so durations can't go negative or jump
            # when the wall clock is adjusted.
            start_time = time.perf_counter()
            correlation_id = get_correlation_id()
            prefix = f"[{correlation_id}] " if correlation_id else ""
            # Avoid building the sanitized params dict if the record would be dropped.
            if logger.isEnabledFor(level):
                if log_params and kwargs:
                    logger.log(
                        level,
                        "%s[%s] %s - ENTER params=%s",
                        prefix,
                        level_name,
                        name,
                        sanitize_params(kwargs),
                    )
                else:
                    logger.log(level, "%s[%s] %s - ENTER", prefix, level_name, name)
            return prefix, start_time

        def _success(prefix: str, start_time: float) -> None:
            """Log the SUCCESS record with the elapsed time."""
            elapsed = time.perf_counter() - start_time
            logger.log(level, "%s[%s] %s - SUCCESS (%.3fs)", prefix, level_name, name, elapsed)

        def _failure(prefix: str, start_time: float, exc: Exception) -> None:
            """Log the ERROR record with the elapsed time and an exception summary."""
            elapsed = time.perf_counter() - start_time
            logger.error(
                "%s[%s] %s - ERROR (%.3fs): %s: %s",
                prefix,
                level_name,
                name,
                elapsed,
                type(exc).__name__,
                exc,
            )

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated since Python 3.14.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                """Wrapper for async functions."""
                prefix, start_time = _enter(kwargs)
                try:
                    result = await func(*args, **kwargs)
                except Exception as e:
                    _failure(prefix, start_time, e)
                    raise
                _success(prefix, start_time)
                return result

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            """Wrapper for sync functions."""
            prefix, start_time = _enter(kwargs)
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                _failure(prefix, start_time, e)
                raise
            _success(prefix, start_time)
            return result

        return sync_wrapper

    return decorator