"""
配置接口模块
提供配置文件加载、验证和管理功能。
"""
import os
import json
import yaml
from pathlib import Path
from typing import Any, Dict, Optional, Union, List
from dataclasses import dataclass, field
import copy
from middleware.error_handler import ConfigurationError
@dataclass
class TemplateConfig:
"""模板配置"""
name: str
description: str
path: str
variables: Dict[str, str] = field(default_factory=dict)
default: bool = False
@dataclass
class ServerConfig:
"""服务器配置"""
host: str = "localhost"
port: int = 8080
max_concurrent_requests: int = 10
request_timeout: int = 30
log_level: str = "INFO"
@dataclass
class SecurityConfig:
"""安全配置"""
allowed_paths: List[str] = field(default_factory=list)
max_file_size: int = 50 * 1024 * 1024 # 50MB
enable_path_validation: bool = True
enable_content_filtering: bool = True
session_timeout: int = 24 * 3600 # 24小时
@dataclass
class PerformanceConfig:
"""性能配置"""
cache_memory_size: int = 1000
cache_disk_enabled: bool = True
cache_disk_dir: str = "cache"
cache_disk_max_size: int = 100 * 1024 * 1024 # 100MB
parallel_workers: int = 4
async_timeout: int = 10
@dataclass
class AppConfig:
"""应用配置"""
server: ServerConfig = field(default_factory=ServerConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
templates: Dict[str, TemplateConfig] = field(default_factory=dict)
folder_descriptions: Dict[str, str] = field(default_factory=dict)
class ConfigInterface:
"""配置接口类"""
def __init__(self, config_dir: str = "config"):
"""
初始化配置接口
Args:
config_dir: 配置文件目录
"""
self.config_dir = Path(config_dir)
self.config_dir.mkdir(parents=True, exist_ok=True)
self._config: Optional[AppConfig] = None
self._watchers = [] # 配置变化监听器
def load_config(self, config_file: Optional[str] = None) -> AppConfig:
"""
加载配置文件
Args:
config_file: 配置文件路径,默认为 config/default.yaml
Returns:
应用配置对象
"""
if config_file is None:
config_file = self.config_dir / "default.yaml"
else:
config_file = Path(config_file)
try:
if not config_file.exists():
# 如果配置文件不存在,创建默认配置
config = AppConfig()
self.save_config(config, config_file)
return config
# 读取配置文件
with open(config_file, 'r', encoding='utf-8') as f:
if config_file.suffix.lower() in ['.yaml', '.yml']:
data = yaml.safe_load(f)
elif config_file.suffix.lower() == '.json':
data = json.load(f)
else:
raise ConfigurationError(
f"不支持的配置文件格式: {config_file.suffix}",
config_key="config_file_format"
)
# 解析配置
config = self._parse_config(data)
self._config = config
# 通知监听器
self._notify_watchers(config)
return config
except (RuntimeError, ValueError) as e:
if isinstance(e, ConfigurationError):
raise
raise ConfigurationError(
f"加载配置文件失败: {str(e)}",
config_key=str(config_file)
)
def save_config(self, config: AppConfig, config_file: Optional[str] = None):
"""
保存配置文件
Args:
config: 应用配置对象
config_file: 配置文件路径
"""
if config_file is None:
config_file = self.config_dir / "default.yaml"
else:
config_file = Path(config_file)
try:
# 确保配置目录存在
config_file.parent.mkdir(parents=True, exist_ok=True)
# 转换为字典
data = self._config_to_dict(config)
# 写入配置文件
with open(config_file, 'w', encoding='utf-8') as f:
if config_file.suffix.lower() in ['.yaml', '.yml']:
yaml.dump(data, f, default_flow_style=False, allow_unicode=True, indent=2)
elif config_file.suffix.lower() == '.json':
json.dump(data, f, ensure_ascii=False, indent=2)
else:
raise ConfigurationError(
f"不支持的配置文件格式: {config_file.suffix}",
config_key="config_file_format"
)
except (RuntimeError, ValueError) as e:
if isinstance(e, ConfigurationError):
raise
raise ConfigurationError(
f"保存配置文件失败: {str(e)}",
config_key=str(config_file)
)
def get_config(self) -> AppConfig:
"""获取当前配置"""
if self._config is None:
return self.load_config()
return copy.deepcopy(self._config)
def update_config(self, updates: Dict[str, Any]):
"""更新配置"""
config = self.get_config()
self._update_nested_object(config, updates)
self._config = config
self.save_config(config)
self._notify_watchers(config)
def _parse_config(self, data: Dict[str, Any]) -> AppConfig:
"""解析配置字典为配置对象"""
config = AppConfig()
# 解析服务器配置
if 'server' in data:
server_data = data['server']
config.server = ServerConfig(
host=server_data.get('host', config.server.host),
port=server_data.get('port', config.server.port),
max_concurrent_requests=server_data.get('max_concurrent_requests', config.server.max_concurrent_requests),
request_timeout=server_data.get('request_timeout', config.server.request_timeout),
log_level=server_data.get('log_level', config.server.log_level)
)
# 解析安全配置
if 'security' in data:
security_data = data['security']
config.security = SecurityConfig(
allowed_paths=security_data.get('allowed_paths', config.security.allowed_paths),
max_file_size=security_data.get('max_file_size', config.security.max_file_size),
enable_path_validation=security_data.get('enable_path_validation', config.security.enable_path_validation),
enable_content_filtering=security_data.get('enable_content_filtering', config.security.enable_content_filtering),
session_timeout=security_data.get('session_timeout', config.security.session_timeout)
)
# 解析性能配置
if 'performance' in data:
perf_data = data['performance']
config.performance = PerformanceConfig(
cache_memory_size=perf_data.get('cache_memory_size', config.performance.cache_memory_size),
cache_disk_enabled=perf_data.get('cache_disk_enabled', config.performance.cache_disk_enabled),
cache_disk_dir=perf_data.get('cache_disk_dir', config.performance.cache_disk_dir),
cache_disk_max_size=perf_data.get('cache_disk_max_size', config.performance.cache_disk_max_size),
parallel_workers=perf_data.get('parallel_workers', config.performance.parallel_workers),
async_timeout=perf_data.get('async_timeout', config.performance.async_timeout)
)
# 解析模板配置
if 'templates' in data:
for name, template_data in data['templates'].items():
config.templates[name] = TemplateConfig(
name=template_data.get('name', name),
description=template_data.get('description', ''),
path=template_data.get('path', ''),
variables=template_data.get('variables', {}),
default=template_data.get('default', False)
)
# 解析文件夹描述
if 'folder_descriptions' in data:
config.folder_descriptions.update(data['folder_descriptions'])
return config
def _config_to_dict(self, config: AppConfig) -> Dict[str, Any]:
"""将配置对象转换为字典"""
return {
'server': {
'host': config.server.host,
'port': config.server.port,
'max_concurrent_requests': config.server.max_concurrent_requests,
'request_timeout': config.server.request_timeout,
'log_level': config.server.log_level
},
'security': {
'allowed_paths': config.security.allowed_paths,
'max_file_size': config.security.max_file_size,
'enable_path_validation': config.security.enable_path_validation,
'enable_content_filtering': config.security.enable_content_filtering,
'session_timeout': config.security.session_timeout
},
'performance': {
'cache_memory_size': config.performance.cache_memory_size,
'cache_disk_enabled': config.performance.cache_disk_enabled,
'cache_disk_dir': config.performance.cache_disk_dir,
'cache_disk_max_size': config.performance.cache_disk_max_size,
'parallel_workers': config.performance.parallel_workers,
'async_timeout': config.performance.async_timeout
},
'templates': {
name: {
'name': template.name,
'description': template.description,
'path': template.path,
'variables': template.variables,
'default': template.default
}
for name, template in config.templates.items()
},
'folder_descriptions': config.folder_descriptions
}
def _update_nested_object(self, obj: Any, updates: Dict[str, Any]):
"""递归更新嵌套对象"""
for key, value in updates.items():
if isinstance(value, dict) and hasattr(obj, key):
current_value = getattr(obj, key)
if isinstance(current_value, dict):
current_value.update(value)
else:
setattr(obj, key, value)
else:
setattr(obj, key, value)
def add_watcher(self, callback):
"""添加配置变化监听器"""
self._watchers.append(callback)
def _notify_watchers(self, config: AppConfig):
"""通知所有监听器"""
for watcher in self._watchers:
try:
watcher(config)
except (RuntimeError, ValueError) as e:
print(f"配置监听器出错: {e}")
def load_folder_descriptions(self, file_path: Optional[str] = None) -> Dict[str, str]:
"""
加载文件夹描述配置
Args:
file_path: 描述文件路径,默认为 config/folder_descriptions.json
Returns:
文件夹描述字典
"""
if file_path is None:
file_path = self.config_dir / "folder_descriptions.json"
else:
file_path = Path(file_path)
try:
if not file_path.exists():
# 创建默认的文件夹描述
descriptions = self._create_default_folder_descriptions()
self._save_folder_descriptions(descriptions, file_path)
return descriptions
with open(file_path, 'r', encoding='utf-8') as f:
descriptions = json.load(f)
return descriptions
except (RuntimeError, ValueError) as e:
raise ConfigurationError(
f"加载文件夹描述失败: {str(e)}",
config_key=str(file_path)
)
def _create_default_folder_descriptions(self) -> Dict[str, str]:
"""创建默认的文件夹描述"""
return {
"src": "源代码目录,包含项目的主要代码实现",
"lib": "库文件目录,存放第三方库或公共模块",
"docs": "文档目录,存放项目相关文档和说明",
"tests": "测试目录,存放单元测试和集成测试代码",
"config": "配置文件目录,存放应用配置文件",
"scripts": "脚本目录,存放自动化脚本和工具",
"assets": "资源目录,存放静态资源文件",
"utils": "工具函数目录,存放通用工具函数",
"components": "组件目录,存放可复用组件",
"services": "服务目录,存放业务逻辑服务",
"models": "模型目录,存放数据模型定义",
"views": "视图目录,存放界面视图代码",
"controllers": "控制器目录,存放业务控制逻辑",
"middleware": "中间件目录,存放中间件代码",
"api": "API目录,存放API接口代码",
"database": "数据库目录,存放数据库相关文件",
"migrations": "数据库迁移目录,存放数据库版本控制文件",
"seeds": "种子数据目录,存放初始数据文件",
"logs": "日志目录,存放应用程序日志",
"cache": "缓存目录,存放缓存文件",
"temp": "临时文件目录,存放临时文件",
"uploads": "上传文件目录,存放用户上传文件",
"backup": "备份目录,存放备份文件",
"public": "公共目录,存放公共访问文件",
"private": "私有目录,存放私有文件",
"static": "静态资源目录,存放静态文件",
"styles": "样式目录,存放CSS样式文件",
"images": "图片目录,存放图片资源",
"fonts": "字体目录,存放字体文件",
"videos": "视频目录,存放视频资源",
"audio": "音频目录,存放音频资源"
}
def _save_folder_descriptions(self, descriptions: Dict[str, str], file_path: Path):
"""保存文件夹描述到文件"""
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(descriptions, f, ensure_ascii=False, indent=2)
# 全局配置接口实例
config_interface = ConfigInterface()