#!/usr/bin/env python3
"""
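Script to fix PyContracts imports throughout the codebase.
Rewrites ``from contracts import ...`` statements so they import from the
project's compatibility shim (``utils/contracts_shim``) instead of the
PyContracts library. Plain ``import contracts`` statements are not rewritten.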
"""
import os
import re
from pathlib import Path
def _shim_import_target(package_parts):
    """Return the relative module path of the shim as seen from a package.

    *package_parts* are the directory names between the source root and the
    importing file, e.g. ``('core', 'sub')`` for ``src/core/sub/x.py``.
    One leading dot names the importing file's own package; each extra dot
    climbs one level.
    """
    depth = len(package_parts)
    if package_parts and package_parts[0] == 'utils':
        # Inside utils (or below it): climb back to utils, then name the module.
        return '.' * depth + 'contracts_shim'
    # Climb to the source root, then descend into utils.
    return '.' * (depth + 1) + 'utils.contracts_shim'


def fix_contract_imports(src_dir):
    """Rewrite ``from contracts import ...`` lines to use the local shim.

    Walks every ``*.py`` file under *src_dir* (except ``contracts_shim.py``
    itself). Each ``from contracts import`` at the start of a line
    (indentation allowed) is replaced with a relative import of the shim.
    The number of leading dots is derived from the file's depth below
    *src_dir*, so top-level files and nested subpackages are handled
    correctly::

        src/top.py          -> from .utils.contracts_shim import ...
        src/core/x.py       -> from ..utils.contracts_shim import ...
        src/core/sub/x.py   -> from ...utils.contracts_shim import ...
        src/utils/x.py      -> from .contracts_shim import ...
        src/utils/sub/x.py  -> from ..contracts_shim import ...

    Line endings are preserved. Files that cannot be read, decoded, or
    written are reported and skipped.

    Args:
        src_dir: Root directory (``str`` or ``Path``) of the package that
            contains ``utils/contracts_shim.py``.

    Returns:
        The number of files that were modified.
    """
    src_root = Path(src_dir)
    # Anchored at line start (indentation allowed) so commented-out mentions
    # like "# from contracts import x" are left alone. The lookahead keeps the
    # original requirement that something follows "import ".
    import_re = re.compile(
        r'^([ \t]*)from contracts import (?=[^\r\n])', re.MULTILINE)

    files_fixed = 0
    for root, _dirs, files in os.walk(src_root):
        for name in files:
            if not name.endswith('.py') or name == 'contracts_shim.py':
                continue
            file_path = Path(root) / name
            # Used only for reporting (e.g. "src/core/x.py").
            relative_path = str(file_path.relative_to(src_root.parent))
            # Use path components, not a string prefix, so this works with
            # any root name and with Windows path separators.
            package_parts = file_path.parent.relative_to(src_root).parts
            target = _shim_import_target(package_parts)
            try:
                # newline='' preserves the file's existing line endings.
                with open(file_path, 'r', encoding='utf-8', newline='') as f:
                    content = f.read()
                new_content = import_re.sub(
                    rf'\g<1>from {target} import ', content)
                if new_content != content:
                    with open(file_path, 'w', encoding='utf-8', newline='') as f:
                        f.write(new_content)
                    print(f"Fixed: {relative_path}")
                    files_fixed += 1
            except (OSError, UnicodeError) as e:
                # Best effort: report and continue with the remaining files.
                print(f"Error processing {file_path}: {e}")
    print(f"Total files fixed: {files_fixed}")
    return files_fixed
if __name__ == "__main__":
    # The package to rewrite lives in a "src" directory next to this script.
    source_root = Path(__file__).parent.joinpath("src")
    fix_contract_imports(source_root)