container.py•16.5 kB
"""Dependency injection container for modular architecture.
This module provides a ServiceContainer that manages service lifetimes and
dependencies, enabling loose coupling, easier testing, and better architecture.
"""
import inspect
import types
from collections.abc import Callable
from enum import Enum
from typing import Any, TypeVar, Union, get_args, get_origin
from .logger import get_logger
logger = get_logger("container")
T = TypeVar("T")
class ServiceLifetime(Enum):
"""Service lifetime management options."""
SINGLETON = "singleton" # One instance for entire application lifetime
SCOPED = "scoped" # New instance for each operation/request
TRANSIENT = "transient" # New instance every time requested
class ServiceRegistration:
"""Registration information for a service."""
def __init__(
self,
interface: type,
implementation: type | Callable,
lifetime: ServiceLifetime,
factory: Callable | None = None,
):
self.interface = interface
self.implementation = implementation
self.lifetime = lifetime
self.factory = factory
class ServiceContainer:
"""Dependency injection container with lifetime management.
This container provides:
- Service registration with different lifetimes
- Automatic dependency resolution
- Constructor injection
- Singleton management
- Scoped service creation
- Factory pattern support
"""
def __init__(self) -> None:
"""Initialize the service container."""
self._registrations: dict[type, ServiceRegistration] = {}
self._singletons: dict[type, Any] = {}
self._factories: dict[type, Callable[..., Any]] = {}
self._resolution_stack: set[type] = set()
logger.debug("ServiceContainer initialized")
def register_singleton(self, interface: type[T], implementation: type[T]) -> None:
"""Register a service as a singleton.
Singleton services are created once and reused for the entire
application lifetime. Ideal for stateless services like repositories.
Args:
interface: The interface/protocol type
implementation: The concrete implementation class
"""
self._registrations[interface] = ServiceRegistration(
interface=interface,
implementation=implementation,
lifetime=ServiceLifetime.SINGLETON,
)
logger.debug(
f"Registered singleton: {interface.__name__} -> {implementation.__name__}"
)
def register_scoped(self, interface: type[T], implementation: type[T]) -> None:
"""Register a service as scoped.
Scoped services are created once per operation/request and disposed
when the operation completes. Ideal for services that maintain state.
Args:
interface: The interface/protocol type
implementation: The concrete implementation class
"""
self._registrations[interface] = ServiceRegistration(
interface=interface,
implementation=implementation,
lifetime=ServiceLifetime.SCOPED,
)
logger.debug(
f"Registered scoped: {interface.__name__} -> {implementation.__name__}"
)
def register_transient(self, interface: type[T], implementation: type[T]) -> None:
"""Register a service as transient.
Transient services are created new every time they are requested.
Ideal for lightweight services without shared state.
Args:
interface: The interface/protocol type
implementation: The concrete implementation class
"""
self._registrations[interface] = ServiceRegistration(
interface=interface,
implementation=implementation,
lifetime=ServiceLifetime.TRANSIENT,
)
logger.debug(
f"Registered transient: {interface.__name__} -> {implementation.__name__}"
)
def register_factory(self, interface: type[T], factory: Callable[..., T]) -> None:
"""Register a service with a custom factory function.
Factory registration allows complex initialization logic and
conditional service creation.
Args:
interface: The interface/protocol type
factory: Factory function that creates the service
"""
self._factories[interface] = factory
self._registrations[interface] = ServiceRegistration(
interface=interface,
implementation=factory,
lifetime=ServiceLifetime.SCOPED, # Factories are scoped by default
factory=factory,
)
logger.debug(f"Registered factory: {interface.__name__}")
def register_instance(self, interface: type[T], instance: T) -> None:
"""Register a pre-created instance as a singleton.
Useful for registering configuration objects or external dependencies.
Args:
interface: The interface/protocol type
instance: Pre-created instance to register
"""
self._singletons[interface] = instance
self._registrations[interface] = ServiceRegistration(
interface=interface,
implementation=type(instance),
lifetime=ServiceLifetime.SINGLETON,
)
logger.debug(f"Registered instance: {interface.__name__}")
def resolve(self, interface: type[T]) -> T:
"""Resolve a service by its interface.
Args:
interface: The interface/protocol type to resolve
Returns:
An instance of the service implementing the interface
Raises:
ValueError: If the service is not registered
RuntimeError: If circular dependency is detected
"""
# Check for circular dependencies
if interface in self._resolution_stack:
stack_str = " -> ".join(
getattr(cls, "__name__", str(cls)) for cls in self._resolution_stack
)
interface_name = getattr(interface, "__name__", str(interface))
raise RuntimeError(
f"Circular dependency detected: {stack_str} -> {interface_name}"
)
# Check if service is registered
if interface not in self._registrations:
interface_name = getattr(interface, "__name__", str(interface))
raise ValueError(f"Service {interface_name} is not registered")
registration = self._registrations[interface]
# Handle singleton lifetime
if registration.lifetime == ServiceLifetime.SINGLETON:
if interface in self._singletons:
return self._singletons[interface]
# Create singleton instance
instance = self._create_instance(interface, registration)
self._singletons[interface] = instance
return instance
# Handle scoped and transient lifetimes
return self._create_instance(interface, registration)
def _create_instance(
self, interface: type[T], registration: ServiceRegistration
) -> T:
"""Create an instance of a service.
Args:
interface: The interface type being resolved
registration: Service registration information
Returns:
New instance of the service
"""
self._resolution_stack.add(interface)
try:
# Handle factory registration
if registration.factory:
return self._create_from_factory(registration.factory)
# Handle class registration
implementation = registration.implementation
# At this point, implementation should be a Type since factory is None
return self._create_from_class(implementation) # type: ignore[arg-type]
finally:
self._resolution_stack.discard(interface)
def _create_from_factory(self, factory: Callable) -> Any:
"""Create instance using a factory function.
Args:
factory: Factory function to call
Returns:
Instance created by the factory
"""
# Get factory signature for dependency injection
signature = inspect.signature(factory)
kwargs = {}
for param_name, param in signature.parameters.items():
if param.annotation != inspect.Parameter.empty:
try:
dependency = self._resolve_optional_dependency(param.annotation)
kwargs[param_name] = dependency
except ValueError as e:
# Skip primitive types - they should use their default values
if "requires a default value" in str(e):
logger.debug(
f"Skipping primitive type parameter {param_name} in factory, using default value"
)
continue
else:
raise
return factory(**kwargs)
def _is_optional_type(self, annotation: Any) -> tuple[bool, type | None]:
"""Check if an annotation is an optional type (Union[T, None] or T | None).
Args:
annotation: Type annotation to check
Returns:
Tuple of (is_optional, non_none_type)
"""
# Handle Union types (both typing.Union and | syntax)
origin = get_origin(annotation)
# Check for both typing.Union and types.UnionType (Python 3.10+)
if origin is Union or isinstance(annotation, types.UnionType):
args = get_args(annotation)
# Check if this is Optional[T] (Union[T, None])
if len(args) == 2 and type(None) in args:
non_none_type = next(arg for arg in args if arg is not type(None))
return True, non_none_type
return False, None
def _is_primitive_type(self, annotation: Any) -> bool:
"""Check if an annotation is a primitive type that should not be dependency-injected.
Args:
annotation: Type annotation to check
Returns:
True if this is a primitive type (int, str, float, bool, etc.)
"""
primitive_types = {int, str, float, bool, bytes, type(None), type, object}
return annotation in primitive_types
def _resolve_optional_dependency(self, param_annotation: Any) -> Any:
"""Resolve an optional dependency, returning None if not registered.
Args:
param_annotation: Parameter type annotation
Returns:
Resolved instance or None if optional dependency is not registered
Raises:
ValueError: If primitive type dependency cannot be resolved (should use default value)
"""
# Check if this is a primitive type that shouldn't be dependency-injected
if self._is_primitive_type(param_annotation):
raise ValueError(
f"Primitive type {param_annotation} requires a default value"
)
is_optional, non_none_type = self._is_optional_type(param_annotation)
if is_optional and non_none_type:
# Check if the non-None type is a primitive
if self._is_primitive_type(non_none_type):
raise ValueError(
f"Optional primitive type {non_none_type} requires a default value"
)
# Check if the non-None type is registered
if self.is_registered(non_none_type):
return self.resolve(non_none_type)
else:
# Optional dependency not registered - return None
logger.debug(
f"Optional dependency {non_none_type} not registered, using None"
)
return None
else:
# Not an optional type, resolve normally
return self.resolve(param_annotation)
def _create_from_class(self, implementation: type) -> Any:
"""Create instance using constructor injection.
Args:
implementation: Class to instantiate
Returns:
New instance with dependencies injected
"""
# Get constructor signature
signature = inspect.signature(implementation)
kwargs = {}
for param_name, param in signature.parameters.items():
if param_name == "self":
continue
if param.annotation != inspect.Parameter.empty:
try:
# Resolve dependency (handles optional types)
dependency = self._resolve_optional_dependency(param.annotation)
kwargs[param_name] = dependency
except ValueError as e:
# Skip primitive types - they should use their default values
if "requires a default value" in str(e):
logger.debug(
f"Skipping primitive type parameter {param_name} in {implementation.__name__}, using default value"
)
continue
else:
raise
elif param.default != inspect.Parameter.empty:
# Use default value if no annotation
continue
else:
# Parameter has no type annotation and no default value
logger.warning(
f"Parameter {param_name} in {implementation.__name__} has no type annotation and no default value"
)
# For testing purposes, we'll skip this parameter, but in production this should be an error
# raise ValueError(f"Cannot resolve parameter {param_name} in {implementation.__name__}: no type annotation or default value")
return implementation(**kwargs)
def is_registered(self, interface: type) -> bool:
"""Check if a service is registered.
Args:
interface: The interface type to check
Returns:
True if the service is registered
"""
return interface in self._registrations
def get_registration(self, interface: type) -> ServiceRegistration | None:
"""Get registration information for a service.
Args:
interface: The interface type to query
Returns:
ServiceRegistration if found, None otherwise
"""
return self._registrations.get(interface)
def clear_singletons(self) -> None:
"""Clear all singleton instances.
Useful for testing or when you need to reset the container state.
"""
self._singletons.clear()
logger.debug("All singleton instances cleared")
def get_registered_services(self) -> dict[type, ServiceRegistration]:
"""Get all registered services.
Returns:
Dictionary mapping interface types to their registrations
"""
return self._registrations.copy()
async def dispose_async(self) -> None:
"""Dispose of resources asynchronously.
Calls dispose methods on services that implement IDisposable
or have async context manager support.
"""
disposed_count = 0
for instance in self._singletons.values():
if hasattr(instance, "__aexit__"):
try:
await instance.__aexit__(None, None, None)
disposed_count += 1
except Exception as e:
logger.error(
f"Error disposing service {type(instance).__name__}: {e}"
)
elif hasattr(instance, "dispose"):
try:
if inspect.iscoroutinefunction(instance.dispose):
await instance.dispose()
else:
instance.dispose()
disposed_count += 1
except Exception as e:
logger.error(
f"Error disposing service {type(instance).__name__}: {e}"
)
logger.debug(f"Disposed {disposed_count} services")
self._singletons.clear()
def __str__(self) -> str:
"""String representation of the container."""
return f"ServiceContainer(services={len(self._registrations)}, singletons={len(self._singletons)})"