"""
Configuration management using Pydantic Settings.
"""
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Literal
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class SentinelHubConfig(BaseSettings):
    """Sentinel Hub API configuration.

    Fields are read from ``SENTINEL_HUB_*`` environment variables and from
    the ``.env`` file. Credentials can alternatively be supplied as a JSON
    object in the ``SENTINELHUB`` environment variable, e.g.
    ``{"client_id": "...", "client_secret": "..."}``; any key present in
    that object overrides the corresponding per-field value.
    """

    model_config = SettingsConfigDict(
        env_prefix="SENTINEL_HUB_",
        # A nested BaseSettings created via the parent's default_factory does
        # not inherit the parent's env_file, so .env must be declared here.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated SENTINEL_HUB_* keys in .env must not fail validation.
        extra="ignore",
    )

    client_id: str = Field(default="")
    client_secret: str = Field(default="")
    instance_id: str = Field(default="")

    @model_validator(mode="before")
    @classmethod
    def load_from_json_env(cls, values):
        """Overlay credentials from the ``SENTINELHUB`` JSON env var, if set.

        Args:
            values: Raw input collected by pydantic-settings before field
                validation (a dict of field name -> value for BaseSettings).

        Returns:
            A new dict with ``client_id`` / ``client_secret`` / ``instance_id``
            replaced by the JSON values that are present and non-null, or
            ``values`` unchanged if the variable is unset, the input is not a
            dict, or the JSON is unusable.

        A missing key in the JSON no longer blanks out a value that came from
        ``SENTINEL_HUB_*``. A malformed blob is skipped with a warning rather
        than aborting startup.
        """
        import warnings

        sentinel_json = os.environ.get("SENTINELHUB")
        if not sentinel_json or not isinstance(values, dict):
            return values

        try:
            data = json.loads(sentinel_json)
        except json.JSONDecodeError as exc:
            warnings.warn(
                f"Ignoring SENTINELHUB environment variable: invalid JSON ({exc})",
                stacklevel=2,
            )
            return values

        # Valid JSON that is not an object (list, string, number) would
        # otherwise crash on `.get` with an uncaught AttributeError.
        if not isinstance(data, dict):
            warnings.warn(
                "Ignoring SENTINELHUB environment variable: expected a JSON object",
                stacklevel=2,
            )
            return values

        merged = dict(values)  # don't mutate the caller's dict
        for key in ("client_id", "client_secret", "instance_id"):
            if data.get(key) is not None:
                merged[key] = data[key]
        return merged

    @property
    def is_configured(self) -> bool:
        """True when both ``client_id`` and ``client_secret`` are non-empty."""
        return bool(self.client_id and self.client_secret)
class DatabaseConfig(BaseSettings):
    """Database settings, read from ``DATABASE_*`` env vars and ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="DATABASE_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; without this, DATABASE_* lines in .env are
        # silently ignored.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated DATABASE_* keys in .env must not fail validation.
        extra="ignore",
    )

    # SQLAlchemy-style URL; the default uses the asyncpg driver.
    url: str = Field(default="postgresql+asyncpg://postgres:password@localhost:5432/geosight")
    # Presumably forwarded to the engine's connection pool — verify at call site.
    pool_size: int = Field(default=10)
    max_overflow: int = Field(default=20)
    # Presumably toggles SQL statement echo/logging — verify at call site.
    echo: bool = Field(default=False)
class RedisConfig(BaseSettings):
    """Redis cache settings, read from ``REDIS_*`` env vars and ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="REDIS_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated REDIS_* keys in .env must not fail validation.
        extra="ignore",
    )

    url: str = Field(default="redis://localhost:6379/0")
    # Default time-to-live for cached entries (3600 s = 1 hour).
    cache_ttl_seconds: int = Field(default=3600)
    # Size budget for the cache in megabytes.
    cache_max_size_mb: int = Field(default=512)
class S3Config(BaseSettings):
    """Object-storage settings, read from ``S3_*`` env vars and ``.env``.

    The defaults (localhost:9000, ``minioadmin`` credentials) look like a
    local MinIO development setup; override them outside development.
    """

    model_config = SettingsConfigDict(
        env_prefix="S3_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated S3_* keys in .env must not fail validation.
        extra="ignore",
    )

    endpoint_url: str = Field(default="http://localhost:9000")
    access_key: str = Field(default="minioadmin")
    secret_key: str = Field(default="minioadmin")
    bucket_name: str = Field(default="geosight-data")
    region: str = Field(default="us-east-1")
class ModelConfig(BaseSettings):
    """ML model settings, read from ``MODEL_*`` env vars and ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="MODEL_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated MODEL_* keys in .env must not fail validation.
        extra="ignore",
    )

    # Directory holding model weight files (relative to the working directory
    # unless an absolute path is given).
    weights_dir: Path = Field(default=Path("./models/weights"))
    inference_device: Literal["cpu", "cuda", "mps"] = Field(default="cpu")
    precision: Literal["fp32", "fp16", "int8"] = Field(default="fp32")
    # Model identifiers for each analysis task.
    land_cover_model: str = Field(default="resnet50")
    change_detection_model: str = Field(default="siamese_unet")
    object_detection_model: str = Field(default="yolov8m")

    @field_validator("weights_dir", mode="before")
    @classmethod
    def validate_weights_dir(cls, v):
        """Convert a string path to ``Path``; pass any other value through."""
        return Path(v) if isinstance(v, str) else v
class ServerConfig(BaseSettings):
    """Server settings, read from ``SERVER_*`` env vars and ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="SERVER_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated SERVER_* keys in .env must not fail validation.
        extra="ignore",
    )

    # How the server is run; one of "stdio", "http", "sse".
    mode: Literal["stdio", "http", "sse"] = Field(default="stdio")
    # Bind address; the default listens on all interfaces.
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)
class CeleryConfig(BaseSettings):
    """Celery task-queue settings, read from ``CELERY_*`` env vars and ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="CELERY_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated CELERY_* keys in .env must not fail validation.
        extra="ignore",
    )

    # Defaults use separate Redis logical databases (1 and 2).
    broker_url: str = Field(default="redis://localhost:6379/1")
    result_backend: str = Field(default="redis://localhost:6379/2")
    # Presumably seconds — confirm against where the timeout is applied.
    task_timeout: int = Field(default=600)
class FeatureFlagsConfig(BaseSettings):
    """Feature toggles, read from ``ENABLE_*`` env vars and ``.env``.

    For example ``ENABLE_CACHING=false`` turns ``caching`` off. Every flag
    defaults to enabled.
    """

    model_config = SettingsConfigDict(
        env_prefix="ENABLE_",
        # Nested settings built via the parent's default_factory do not
        # inherit its env_file; declare it here so .env is honoured.
        env_file=".env",
        env_file_encoding="utf-8",
        # Unrelated ENABLE_* keys in .env must not fail validation.
        extra="ignore",
    )

    object_detection: bool = Field(default=True)
    change_detection: bool = Field(default=True)
    report_generation: bool = Field(default=True)
    caching: bool = Field(default=True)
class Settings(BaseSettings):
    """Top-level application settings.

    Top-level fields are read from environment variables and the ``.env``
    file (UTF-8); unknown keys are ignored. Each nested section is built by
    its own settings class through ``default_factory``.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    environment: Literal["development", "staging", "production"] = Field(default="development")
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO")
    # Empty by default.
    api_key: str = Field(default="")
    # Comma-separated list of allowed origins; parsed by `cors_origins_list`.
    cors_origins: str = Field(default="http://localhost:3000,http://localhost:8501")
    rate_limit_requests_per_minute: int = Field(default=60)
    rate_limit_burst: int = Field(default=10)
    prometheus_enabled: bool = Field(default=True)
    prometheus_port: int = Field(default=9090)
    # Empty by default.
    sentry_dsn: str = Field(default="")

    # Nested configuration sections.
    sentinel_hub: SentinelHubConfig = Field(default_factory=SentinelHubConfig)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    redis: RedisConfig = Field(default_factory=RedisConfig)
    s3: S3Config = Field(default_factory=S3Config)
    model: ModelConfig = Field(default_factory=ModelConfig)
    server: ServerConfig = Field(default_factory=ServerConfig)
    celery: CeleryConfig = Field(default_factory=CeleryConfig)
    features: FeatureFlagsConfig = Field(default_factory=FeatureFlagsConfig)

    @property
    def cors_origins_list(self) -> list[str]:
        """Return ``cors_origins`` as a list of trimmed, non-empty origins.

        An empty ``cors_origins`` yields ``[]`` (not ``[""]``), and stray or
        trailing commas do not produce blank entries.
        """
        stripped = (part.strip() for part in self.cors_origins.split(","))
        return [origin for origin in stripped if origin]

    @property
    def is_production(self) -> bool:
        """True when ``environment`` is ``"production"``."""
        return self.environment == "production"

    @property
    def is_development(self) -> bool:
        """True when ``environment`` is ``"development"``."""
        return self.environment == "development"
@lru_cache()
def get_settings() -> Settings:
    """Construct the application ``Settings`` once and memoize the result.

    Every later call returns the same instance; use
    ``get_settings.cache_clear()`` to force the settings to be rebuilt.
    """
    loaded = Settings()
    return loaded


# Shared instance, created at import time through the memoized factory.
settings = get_settings()