# settings.py (11.8 kB)
#!/usr/bin/env python3
"""
Configuration Settings
Centralized configuration management for AnyDocs-MCP.
"""
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import List

import yaml
from pydantic import BaseModel, Field, ConfigDict
from pydantic import Field, validator
from pydantic import field_validator
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
class LogLevel(str, Enum):
    """Logging severity names used by the ``log_level`` settings.

    Every member's value is its own upper-case name. Because the enum mixes
    in ``str``, a member compares equal to that plain string, and callers
    in this module pass ``.value`` on to the logging configuration.
    """

    DEBUG = "DEBUG"        # most verbose
    INFO = "INFO"          # default for ServerConfig / Settings
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"  # least verbose
class AuthMethod(str, Enum):
    """Identifiers for the supported authentication schemes.

    Values are lower-case strings. ``NONE`` holds the literal string
    ``"none"``, not Python's ``None``.
    """

    API_KEY = "api_key"  # default for Settings.auth_method
    OAUTH2 = "oauth2"
    JWT = "jwt"
    NONE = "none"
class DatabaseType(str, Enum):
    """Database backends selectable through ``database_type``.

    ``SQLITE`` is the default; ``Settings.get_database_url`` treats it
    specially by creating the database file's parent directory.
    """

    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"
@dataclass
class DocSourceConfig:
    """Description of one documentation source and the adapter serving it.

    Attributes:
        name: Identifier of the source; must be non-empty.
        adapter_type: Adapter key such as gitbook, notion or confluence;
            must be non-empty.
        config: Adapter-specific options. Each instance gets its own dict.
        enabled: On/off flag for the source (defaults to True).
        priority: Integer priority (defaults to 1).
        cache_ttl: Cache lifetime in seconds (defaults to 3600).

    Raises:
        ValueError: On construction, when ``name`` or ``adapter_type`` is empty.
    """

    name: str
    adapter_type: str  # gitbook, notion, confluence, etc.
    config: Dict[str, Any] = field(default_factory=dict)
    enabled: bool = True
    priority: int = 1
    cache_ttl: int = 3600  # seconds

    def __post_init__(self) -> None:
        """Reject empty identifiers; ``name`` is checked before ``adapter_type``."""
        required_fields = (
            (self.name, "Doc source name cannot be empty"),
            (self.adapter_type, "Adapter type cannot be empty"),
        )
        for value, message in required_fields:
            if not value:
                raise ValueError(message)
@dataclass
class ServerConfig:
    """MCP server configuration.

    The field names match the keys of the ``server`` section built by
    ``create_default_config``, so ``ServerConfig(**cfg["server"])`` works
    directly. That section (like anything read from YAML) carries
    ``log_level`` as a plain string, so ``__post_init__`` normalizes it to
    :class:`LogLevel`.

    Raises:
        ValueError: If ``log_level`` is not a known level name (matched
            case-insensitively) or ``port`` is outside 1-65535.
    """

    name: str = "AnyDocs-MCP"
    version: str = "0.1.0"
    description: str = "Transform documentation into MCP-compatible server"
    host: str = "localhost"
    port: int = 8000
    debug: bool = False
    log_level: LogLevel = LogLevel.INFO
    max_workers: int = 4
    timeout: int = 30  # presumably seconds -- TODO confirm with the server code

    def __post_init__(self) -> None:
        """Coerce ``log_level`` to :class:`LogLevel` and validate ``port``."""
        if not isinstance(self.log_level, LogLevel):
            # Plain strings such as "INFO" or "info" come from config dicts.
            # Converting them lets callers rely on ``.value`` uniformly.
            self.log_level = LogLevel(str(self.log_level).upper())
        # Same range and message as Settings.validate_port.
        if not 1 <= self.port <= 65535:
            raise ValueError("Port must be between 1 and 65535")
class AuthConfig(BaseModel):
    """Authentication options, nested under ``Settings.auth``.

    With ``populate_by_name=True``, ``jwt_secret`` and ``token_expire_minutes``
    accept either their Python name or their upper-case alias. With
    ``extra='ignore'``, unknown keys are dropped rather than rejected.
    """

    # NOTE(review): ``env_prefix`` is a pydantic-settings (BaseSettings) option.
    # On a plain BaseModel it does not appear to drive any environment lookup;
    # confirm before relying on AUTH_* variables reaching this model.
    model_config = ConfigDict(
        env_prefix='AUTH_',
        populate_by_name=True,
        extra='ignore',
    )

    # Enabled authentication methods. The description string is Chinese for
    # "authentication methods".
    methods: List[str] = Field(default=['jwt'], description='认证方式')
    # NOTE(review): weak placeholder secret; supply a real value in deployment.
    jwt_secret: str = Field(default='default_secret', alias='JWT_SECRET')
    token_expire_minutes: int = Field(default=30, alias='TOKEN_EXPIRE_MINUTES')
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Values are read from the process environment and from a UTF-8 ``.env``
    file. pydantic-settings derives each variable name from the field name.
    Because ``case_sensitive=False``, matching is case-insensitive: for
    example, ``SERVER_PORT`` populates ``server_port``.
    """

    # pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # Nested authentication options.
    auth: AuthConfig = Field(default_factory=AuthConfig)

    # Application Info
    app_name: str = Field(default="AnyDocs-MCP")
    app_version: str = Field(default="0.1.0")
    app_description: str = Field(default="Transform documentation into MCP-compatible server")

    # Server Configuration
    server_host: str = Field(default="localhost")
    server_port: int = Field(default=8000)
    server_debug: bool = Field(default=False)
    server_reload: bool = Field(default=False)

    # Logging
    log_level: LogLevel = Field(default=LogLevel.INFO)
    log_file: Optional[str] = Field(default=None)
    log_format: str = Field(default="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Database
    database_type: DatabaseType = Field(default=DatabaseType.SQLITE)
    database_url: str = Field(default="sqlite:///./anydocs_mcp.db")
    database_echo: bool = Field(default=False)

    # Authentication
    auth_method: AuthMethod = Field(default=AuthMethod.API_KEY)
    # NOTE(review): this default is 27 characters, below the 32-character
    # minimum that validate_secret_key enforces. It is accepted only because
    # pydantic does not validate defaults unless validate_default is enabled.
    # Always set SECRET_KEY in a real deployment.
    secret_key: str = Field(default="your-secret-key-change-this")
    access_token_expire_minutes: int = Field(default=30)

    # API Keys (for external services)
    openai_api_key: Optional[str] = Field(default=None)
    anthropic_api_key: Optional[str] = Field(default=None)

    # Cache Configuration
    cache_enabled: bool = Field(default=True)
    cache_ttl: int = Field(default=3600)  # 1 hour
    cache_max_size: int = Field(default=1000)

    # Rate Limiting
    rate_limit_enabled: bool = Field(default=True)
    rate_limit_requests: int = Field(default=100)
    rate_limit_window: int = Field(default=60)  # seconds

    # CORS
    # NOTE(review): pydantic-settings JSON-decodes env/.env values for list
    # fields before "before" validators run. A comma-separated value such as
    # CORS_ORIGINS=a,b may therefore be rejected while a JSON array is
    # accepted. Confirm against the installed pydantic-settings version.
    cors_enabled: bool = Field(default=True)
    cors_origins: List[str] = Field(default=["*"])
    cors_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE"])

    # File Upload
    max_file_size: int = Field(default=10 * 1024 * 1024)  # 10MB
    allowed_file_types: List[str] = Field(default=[".md", ".txt", ".json", ".yaml", ".yml"])

    # Configuration File Path
    config_file: Optional[str] = Field(default=None)

    @field_validator("server_port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Reject ports outside the TCP range 1-65535."""
        if not 1 <= v <= 65535:
            raise ValueError("Port must be between 1 and 65535")
        return v

    @field_validator("secret_key")
    @classmethod
    def validate_secret_key(cls, v: str) -> str:
        """Require an explicitly supplied secret key of at least 32 characters."""
        if len(v) < 32:
            raise ValueError("Secret key must be at least 32 characters long")
        return v

    @field_validator("cors_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v: Any) -> Any:
        """Split a comma-separated string into a list of stripped origins."""
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",")]
        return v

    @field_validator("cors_methods", mode="before")
    @classmethod
    def parse_cors_methods(cls, v: Any) -> Any:
        """Split a comma-separated string into upper-cased HTTP method names."""
        if isinstance(v, str):
            return [method.strip().upper() for method in v.split(",")]
        return v

    @field_validator("allowed_file_types", mode="before")
    @classmethod
    def parse_allowed_file_types(cls, v: Any) -> Any:
        """Split a comma-separated string into a list of stripped extensions."""
        if isinstance(v, str):
            return [ext.strip() for ext in v.split(",")]
        return v

    def get_database_url(self) -> str:
        """Return ``database_url`` unchanged.

        Side effect: for SQLite URLs of the form ``sqlite:///<path>``, the
        parent directory of ``<path>`` is created if it does not exist.
        """
        sqlite_prefix = "sqlite:///"
        if self.database_type == DatabaseType.SQLITE and self.database_url.startswith(sqlite_prefix):
            db_path = Path(self.database_url[len(sqlite_prefix):])
            db_path.parent.mkdir(parents=True, exist_ok=True)
        return self.database_url

    def is_development(self) -> bool:
        """True when debug mode is on or the log level is DEBUG."""
        return self.server_debug or self.log_level == LogLevel.DEBUG

    def is_production(self) -> bool:
        """Logical negation of :meth:`is_development`."""
        return not self.is_development()

    def get_log_config(self) -> Dict[str, Any]:
        """Build a ``logging.config.dictConfig``-style configuration dict.

        Every call returns a new dict.
        - A stdout console handler is always present.
        - When ``log_file`` is set, the file's parent directory is created and
          a rotating file handler (10 MB x 5 backups, "detailed" format) is
          attached to every logger and to the root.
        """
        config: Dict[str, Any] = {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": self.log_format,
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "level": self.log_level.value,
                    "formatter": "default",
                    "stream": "ext://sys.stdout",
                },
            },
            "loggers": {
                "anydocs_mcp": {
                    "level": self.log_level.value,
                    "handlers": ["console"],
                    "propagate": False,
                },
                "uvicorn": {
                    # uvicorn's own level is pinned to INFO; the console handler
                    # still filters by the configured level.
                    "level": "INFO",
                    "handlers": ["console"],
                    "propagate": False,
                },
            },
            "root": {
                "level": self.log_level.value,
                "handlers": ["console"],
            },
        }
        # Add file handler if log file is specified
        if self.log_file:
            log_path = Path(self.log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)
            config["handlers"]["file"] = {
                "class": "logging.handlers.RotatingFileHandler",
                "level": self.log_level.value,
                "formatter": "detailed",
                "filename": str(log_path),
                "maxBytes": 10 * 1024 * 1024,  # 10MB
                "backupCount": 5,
            }
            # Add file handler to all loggers
            for logger_config in config["loggers"].values():
                logger_config["handlers"].append("file")
            config["root"]["handlers"].append("file")
        return config
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load a configuration mapping from a YAML file.

    Args:
        config_path: Path to the YAML file, as ``str`` or ``Path``.

    Returns:
        The parsed top-level mapping. An empty dict is returned when the
        document is empty or otherwise falsy (e.g. ``null``).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is not valid YAML, or its top-level value is
            not a mapping.
        RuntimeError: For any other error while opening or reading the file.
    """
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML configuration: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Error loading configuration: {e}") from e
    if not config:
        return {}
    # A list or scalar document parses fine but violates the Dict contract.
    # Reject it here, outside the try, so it is not rewrapped as RuntimeError.
    if not isinstance(config, dict):
        raise ValueError(
            "Invalid YAML configuration: expected a mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
def save_yaml_config(config: Dict[str, Any], config_path: Union[str, Path]) -> None:
    """Write ``config`` to ``config_path`` as block-style YAML (UTF-8).

    Missing parent directories are created first. An error raised while
    creating them propagates unchanged. Keys are written sorted, with
    two-space indentation.

    Raises:
        RuntimeError: If opening, serializing or writing the file fails.
    """
    target = Path(config_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): yaml.dump uses the full Dumper and may emit Python-specific
    # tags for non-plain objects, which load_yaml_config (safe_load) rejects.
    # Consider yaml.safe_dump; confirm what callers pass in.
    dump_options = dict(default_flow_style=False, indent=2, sort_keys=True)
    try:
        with target.open('w', encoding='utf-8') as stream:
            yaml.dump(config, stream, **dump_options)
    except Exception as exc:
        raise RuntimeError(f"Error saving configuration: {exc}")
def get_default_config_path() -> Path:
    """Return the first existing configuration file from the search list.

    The locations are checked in this order:
        1. ./config.yaml
        2. ./config.yml
        3. ~/.anydocs-mcp/config.yaml
        4. /etc/anydocs-mcp/config.yaml
    The ./ entries are relative to the current working directory. When none
    of them exists, ./config.yaml is returned.
    """
    cwd = Path.cwd()
    candidates = (
        cwd / "config.yaml",
        cwd / "config.yml",
        Path.home() / ".anydocs-mcp" / "config.yaml",
        Path("/etc/anydocs-mcp/config.yaml"),
    )
    return next((candidate for candidate in candidates if candidate.exists()), candidates[0])
def create_default_config() -> Dict[str, Any]:
    """Build a fresh default configuration mapping.

    Each call returns new, independent nested dicts and lists, so the result
    is safe to mutate. The sections are server, database, auth, cache,
    rate_limit and cors, plus an empty ``doc_sources`` list. The auth
    secret is a placeholder meant to be replaced before production use.
    """
    server = dict(
        name="AnyDocs-MCP",
        version="0.1.0",
        description="Transform documentation into MCP-compatible server",
        host="localhost",
        port=8000,
        debug=False,
        log_level="INFO",
    )
    database = dict(type="sqlite", url="sqlite:///./anydocs_mcp.db", echo=False)
    auth = dict(
        method="api_key",
        secret_key="your-secret-key-change-this-in-production",
        access_token_expire_minutes=30,
    )
    cache = dict(enabled=True, ttl=3600, max_size=1000)
    rate_limit = dict(enabled=True, requests=100, window=60)
    cors = dict(enabled=True, origins=["*"], methods=["GET", "POST", "PUT", "DELETE"])
    return dict(
        server=server,
        database=database,
        auth=auth,
        cache=cache,
        rate_limit=rate_limit,
        cors=cors,
        doc_sources=[],
    )
class Config:
    """Module-level holder for the ``AUTH_`` environment-variable prefix.

    NOTE(review): nothing visible in this file reads this class. It is kept
    so any external ``import Config`` keeps working; confirm and remove.
    """

    env_prefix = 'AUTH_'


logger = logging.getLogger(__name__)

# Import-time diagnostic for the auth environment variables. The message
# prefix is Chinese for "environment variable load result". Stdlib loggers
# take %-style positional args; keyword arguments would raise TypeError. The
# JWT secret's value is never logged, only whether it is set.
logger.info(
    '环境变量加载结果: AUTH_METHODS=%s, AUTH_JWT_SECRET=%s',
    os.getenv('AUTH_METHODS'),
    '<set>' if os.getenv('AUTH_JWT_SECRET') else '<unset>',
)