import ast
import builtins
import os
import re
from pathlib import Path
from typing import Dict, List, Set, Tuple
class ExceptionOptimizer:
    """Narrow broad ``except Exception`` handlers in Python source files.

    Candidate exception names come from three heuristics: keywords in the
    file path (see ``get_context_from_file_path``), keywords in the code
    surrounding a handler (``function_mappings``), and keywords on the handler
    line itself (json / yaml / yml / path).

    NOTE(review): narrowing a handler changes runtime behaviour -- anything the
    old ``except Exception`` caught but the new tuple does not will now
    propagate. Every rewrite should be reviewed by a human.
    """

    # (keywords, context) rules checked in order by get_context_from_file_path;
    # order matters because a single path may contain several keywords.
    _PATH_CONTEXT_RULES = (
        (('cache',), 'cache'),
        (('config',), 'config'),
        (('template',), 'template'),
        (('security',), 'security'),
        (('version',), 'version'),
        (('file', 'path'), 'file'),
    )

    # Directory names optimize_project never descends into.
    _SKIPPED_DIRS = frozenset({'__pycache__', '.git', 'venv', 'mcp-env'})

    # A single-line ``except Exception[ as name]:`` clause. Groups: leading
    # indentation, optional ``as`` target, and everything after the colon
    # (inline body or trailing comment) -- all three are preserved on rewrite.
    _BROAD_EXCEPT_RE = re.compile(r'^(\s*)except\s+Exception(?:\s+as\s+(\w+))?\s*:(.*)$')

    def __init__(self):
        # Base candidate exception names per path-derived context.
        # NOTE: get_context_from_file_path never returns 'network' or 'data',
        # so those two entries are currently unused.
        self.context_mappings = {
            'file': ['FileNotFoundError', 'PermissionError', 'OSError', 'IOError'],
            'cache': ['CacheError', 'ValueError', 'RuntimeError'],
            'config': ['ValueError', 'TypeError', 'ConfigError'],
            'template': ['TemplateError', 'ValueError', 'RuntimeError'],
            'network': ['ConnectionError', 'TimeoutError', 'HTTPError'],
            'data': ['ValueError', 'TypeError', 'KeyError', 'IndexError'],
            'security': ['SecurityError', 'ValueError', 'PermissionError'],
            'version': ['ValueError', 'TypeError', 'VersionError'],
            'default': ['ValueError', 'RuntimeError', 'TypeError']
        }
        # Extra candidates added when the lower-cased code around a handler
        # contains the key as a plain substring (so 'get' also matches 'target').
        self.function_mappings = {
            'open': ['FileNotFoundError', 'PermissionError', 'OSError'],
            'read': ['IOError', 'PermissionError'],
            'write': ['IOError', 'PermissionError'],
            'parse': ['ValueError', 'TypeError'],
            'validate': ['ValueError', 'TypeError'],
            'connect': ['ConnectionError', 'TimeoutError'],
            'get': ['KeyError', 'IndexError'],
            'set': ['ValueError', 'TypeError']
        }

    def get_context_from_file_path(self, filepath: str) -> str:
        """Infer a ``context_mappings`` key from keywords in *filepath*.

        Keywords are matched case-insensitively as substrings, in the priority
        order of ``_PATH_CONTEXT_RULES``; the first hit wins and ``'default'``
        is returned when nothing matches.
        """
        filepath_lower = filepath.lower()
        for keywords, context in self._PATH_CONTEXT_RULES:
            if any(keyword in filepath_lower for keyword in keywords):
                return context
        return 'default'

    def get_specific_exceptions(self, filepath: str, function_content: str, line_content: str) -> List[str]:
        """Collect candidate exception names for one handler.

        Args:
            filepath: Path of the file being processed; selects the base context.
            function_content: Source text around the handler; every
                ``function_mappings`` key found in it (case-insensitive
                substring) contributes its exceptions.
            line_content: The handler line; mentions of json / yaml / yml /
                path add format-specific names.

        Returns:
            De-duplicated candidate names sorted alphabetically. Names are not
            checked for existence (e.g. ``'ConfigError'`` may be undefined).
        """
        context = self.get_context_from_file_path(filepath)
        candidates = set(self.context_mappings[context])

        function_lower = function_content.lower()
        for func_name, exceptions in self.function_mappings.items():
            if func_name in function_lower:
                candidates.update(exceptions)

        line_lower = line_content.lower()
        if 'json' in line_lower:
            candidates.update(('JSONDecodeError', 'ValueError'))
        if 'yaml' in line_lower or 'yml' in line_lower:
            candidates.update(('YAMLError', 'ValueError'))
        if 'path' in line_lower:
            candidates.update(('PathError', 'ValueError'))
        return sorted(candidates)

    @staticmethod
    def _broad_handler_lines(tree: ast.AST) -> Set[int]:
        """Return 0-based line indices of every ``except Exception`` handler in *tree*."""
        return {
            node.lineno - 1
            for node in ast.walk(tree)
            if isinstance(node, ast.ExceptHandler)
            and isinstance(node.type, ast.Name)
            and node.type.id == 'Exception'
        }

    @staticmethod
    def _resolvable_exception_names(tree: ast.Module) -> Set[str]:
        """Return names an ``except`` clause in this module can use without a NameError.

        That is: builtin exception classes, plus classes defined and names
        imported at module top level. ``from x import *`` is not resolved.
        """
        names = {
            name
            for name, obj in vars(builtins).items()
            if isinstance(obj, type) and issubclass(obj, BaseException)
        }
        for node in tree.body:
            if isinstance(node, ast.ClassDef):
                names.add(node.name)
            elif isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    # `import a.b` binds `a`; `from m import X as Y` binds `Y`.
                    names.add(alias.asname or alias.name.split('.')[0])
        return names

    def optimize_exception_in_file(self, filepath: Path, max_exceptions: int = 3) -> Tuple[int, int]:
        """Narrow the ``except Exception`` handlers of one Python file in place.

        Only real handlers located via ``ast`` are touched, so class
        definitions, ``raise`` statements, strings and comments that merely
        mention ``Exception`` are left alone. The handler's ``as`` target and
        any text after the colon are preserved. Candidate names the module
        cannot resolve are dropped so the rewrite never introduces a NameError;
        a handler with no resolvable candidates is left unchanged.

        Args:
            filepath: File to rewrite.
            max_exceptions: Maximum number of names placed in each tuple
                (taken from the alphabetically sorted candidates).

        Returns:
            ``(rewritten_handler_count, 1 if the file was rewritten else 0)``.
            Files that cannot be read as UTF-8 or parsed return ``(0, 0)``.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content, filename=str(filepath))
        except (OSError, UnicodeDecodeError, SyntaxError, ValueError):
            # Unreadable or unparsable files are skipped, never guessed at.
            return 0, 0

        lines = content.split('\n')
        # Rewrites go into a copy so later context windows see the original text.
        optimized_lines = list(lines)
        resolvable = self._resolvable_exception_names(tree)
        changes = 0

        for i in sorted(self._broad_handler_lines(tree)):
            match = self._BROAD_EXCEPT_RE.match(lines[i])
            if match is None:
                # Handler split over several lines, parenthesised, etc.: leave it.
                continue
            # Context window: up to 10 lines above, the handler, and 4 below.
            function_context = '\n'.join(lines[max(0, i - 10):i + 5])
            candidates = [
                exc
                for exc in self.get_specific_exceptions(str(filepath), function_context, lines[i])
                if exc in resolvable
            ][:max_exceptions]
            if not candidates:
                continue
            indent, target, rest = match.groups()
            as_clause = f' as {target}' if target else ''
            optimized_lines[i] = f"{indent}except ({', '.join(candidates)}){as_clause}:{rest}"
            changes += 1

        if changes > 0:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write('\n'.join(optimized_lines))
        return changes, 1 if changes > 0 else 0

    def optimize_project(self, project_root: str) -> Dict:
        """Run ``optimize_exception_in_file`` on every ``.py`` file under *project_root*.

        Directories listed in ``_SKIPPED_DIRS`` are pruned from the walk.

        Returns:
            A dict with ``files_processed``, ``files_optimized``,
            ``total_optimizations`` and ``optimized_files`` (a list of
            ``{'file': <path relative to project_root>, 'optimizations': <count>}``).
        """
        project_path = Path(project_root)
        results = {
            'files_processed': 0,
            'files_optimized': 0,
            'total_optimizations': 0,
            'optimized_files': [],
        }
        for root, dirs, files in os.walk(project_path):
            # In-place assignment prunes os.walk's descent; sorting keeps the
            # traversal order (and therefore the report) deterministic.
            dirs[:] = sorted(d for d in dirs if d not in self._SKIPPED_DIRS)
            for file in sorted(files):
                if not file.endswith('.py'):
                    continue
                filepath = Path(root) / file
                results['files_processed'] += 1
                optimizations, file_optimized = self.optimize_exception_in_file(filepath)
                results['total_optimizations'] += optimizations
                results['files_optimized'] += file_optimized
                if optimizations > 0:
                    results['optimized_files'].append({
                        'file': str(filepath.relative_to(project_path)),
                        'optimizations': optimizations,
                    })
        return results

    def generate_optimization_report(self, results: Dict) -> str:
        """Render an ``optimize_project`` result dict as a plain-text report (Chinese labels)."""
        report = [
            "=" * 60,
            "异常处理优化报告",
            "=" * 60,
            f"处理文件数: {results['files_processed']}",
            f"优化文件数: {results['files_optimized']}",
            f"总优化数: {results['total_optimizations']}",
            "",
        ]
        if results['optimized_files']:
            report.append("已优化的文件:")
            for file_info in results['optimized_files']:
                report.append(f" - {file_info['file']} ({file_info['optimizations']}个优化)")
            report.append("")
        report.extend([
            "优化策略:",
            " 1. 根据文件上下文推断异常类型",
            " 2. 根据函数内容选择特定异常",
            " 3. 避免过于宽泛的Exception捕获",
            " 4. 保持异常处理的精确性",
        ])
        return "\n".join(report)
if __name__ == "__main__":
    # Script entry: scan ./src, print the summary, then persist the same text
    # to exception_optimization_report.txt in the current working directory.
    exception_optimizer = ExceptionOptimizer()
    summary = exception_optimizer.generate_optimization_report(
        exception_optimizer.optimize_project("src")
    )
    print(summary)
    Path("exception_optimization_report.txt").write_text(summary, encoding="utf-8")