# coupling_analyzer.py
"""
Code Coupling and Cohesion Analyzer
"""
import ast
import re
from pathlib import Path
from typing import Dict, Set, List, Tuple, Optional, Any
from collections import defaultdict
from .base_metrics import BaseMetricsAnalyzer, FileMetrics, ClassMetrics
class CouplingAnalyzer(BaseMetricsAnalyzer):
"""Analyzes coupling and cohesion metrics for code"""
    def __init__(self, language: str = "python", project_root: Optional[Path] = None):
"""
Initialize coupling analyzer
Args:
language: Programming language
project_root: Root directory for resolving imports
"""
super().__init__(language)
self.project_root = project_root or Path.cwd()
self.import_graph = defaultdict(set) # file -> set of dependencies
self.export_graph = defaultdict(set) # file -> set of exported items
self.class_dependencies = defaultdict(set) # class -> set of dependencies
def analyze_file(self, file_path: Path, content: str) -> FileMetrics:
"""
Analyze coupling metrics for a file
Args:
file_path: Path to the file
content: File content
Returns:
FileMetrics with coupling information
"""
metrics = FileMetrics(file_path=str(file_path))
if self.language == "python":
self._analyze_python_coupling(content, metrics, file_path)
elif self.language in ["javascript", "typescript"]:
self._analyze_js_coupling(content, metrics, file_path)
# Calculate instability
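        # Instability I = Ce / (Ca + Ce): 0.0 means maximally stable (only
        # incoming dependencies), 1.0 means maximally unstable (only outgoing).
        # Afferent coupling (Ca) is only filled in later by
        # analyze_project_coupling, so the value computed here is provisional.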
if metrics.coupling_afferent + metrics.coupling_efferent > 0:
metrics.instability = metrics.coupling_efferent / (
metrics.coupling_afferent + metrics.coupling_efferent
)
return metrics
def _analyze_python_coupling(self, content: str, metrics: FileMetrics, file_path: Path):
"""Analyze Python coupling using AST"""
try:
tree = ast.parse(content)
# Extract imports (efferent coupling)
imports = self._extract_python_imports(tree)
metrics.coupling_efferent = len(imports)
self.import_graph[str(file_path)] = imports
# Extract exported items
exports = self._extract_python_exports(tree)
self.export_graph[str(file_path)] = exports
# Analyze classes
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_metrics = self._analyze_python_class_coupling(node, content)
metrics.classes.append(class_metrics)
# Track class dependencies
class_deps = self._extract_class_dependencies(node)
self.class_dependencies[node.name] = class_deps
        except SyntaxError:
            # Fallback to regex analysis
            imports = self._extract_imports_regex(content)
            metrics.coupling_efferent = len(imports)
            self.import_graph[str(file_path)] = imports
def _extract_python_imports(self, tree: ast.AST) -> Set[str]:
"""Extract all imports from Python AST"""
imports = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
imports.add(alias.name.split('.')[0])
elif isinstance(node, ast.ImportFrom):
if node.module:
imports.add(node.module.split('.')[0])
return imports
def _extract_python_exports(self, tree: ast.AST) -> Set[str]:
"""Extract exported items (classes, functions) from Python AST"""
exports = set()
for node in tree.body:
if isinstance(node, ast.ClassDef):
exports.add(node.name)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Only public functions (not starting with _)
                if not node.name.startswith('_'):
                    exports.add(node.name)
elif isinstance(node, ast.Assign):
# Module-level variables
for target in node.targets:
if isinstance(target, ast.Name) and not target.id.startswith('_'):
exports.add(target.id)
return exports
def _analyze_python_class_coupling(self, node: ast.ClassDef, content: str) -> ClassMetrics:
"""Analyze coupling for a Python class"""
class_metrics = ClassMetrics(
name=node.name,
line_number=node.lineno
)
        # Count methods and fields (class-level and instance attributes)
        methods = []
        fields = set()
        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                methods.append(item)
                class_metrics.methods_count += 1
            elif isinstance(item, ast.Assign):
                for target in item.targets:
                    if isinstance(target, ast.Name) and target.id not in fields:
                        fields.add(target.id)
                        class_metrics.fields_count += 1
        # Also treat attributes assigned via `self.<attr> = ...` inside methods
        # as fields, so LCOM sees state created in __init__ and friends
        for method in methods:
            for sub in ast.walk(method):
                if isinstance(sub, ast.Assign):
                    for target in sub.targets:
                        if (isinstance(target, ast.Attribute)
                                and isinstance(target.value, ast.Name)
                                and target.value.id == 'self'
                                and target.attr not in fields):
                            fields.add(target.attr)
                            class_metrics.fields_count += 1
# Calculate LCOM (Lack of Cohesion of Methods)
class_metrics.cohesion = self._calculate_lcom(methods, fields)
# Count base classes (efferent coupling for class)
class_metrics.coupling_efferent = len(node.bases)
        # Inheritance depth (simplified: number of direct bases, not full MRO depth)
class_metrics.inheritance_depth = len(node.bases)
return class_metrics
def _extract_class_dependencies(self, node: ast.ClassDef) -> Set[str]:
"""Extract dependencies for a class"""
dependencies = set()
# Base classes
for base in node.bases:
if isinstance(base, ast.Name):
dependencies.add(base.id)
elif isinstance(base, ast.Attribute):
dependencies.add(base.attr)
# Used types in methods
for item in ast.walk(node):
if isinstance(item, ast.Name) and item.id[0].isupper():
# Likely a class reference
dependencies.add(item.id)
return dependencies
def _calculate_lcom(self, methods: List[ast.FunctionDef], fields: Set[str]) -> float:
"""
Calculate Lack of Cohesion of Methods (LCOM)
LCOM measures how well methods in a class use the class fields.
Lower values indicate better cohesion.
Args:
methods: List of method AST nodes
fields: Set of field names
Returns:
LCOM score (0-1, lower is better)
"""
if len(methods) <= 1 or len(fields) == 0:
return 0.0
# Count field usage by each method
method_field_usage = {}
for method in methods:
used_fields = set()
for node in ast.walk(method):
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Name) and node.value.id == 'self':
if node.attr in fields:
used_fields.add(node.attr)
method_field_usage[method.name] = used_fields
# Calculate LCOM
# Count method pairs that don't share fields
no_shared_fields = 0
shared_fields = 0
method_names = list(method_field_usage.keys())
for i in range(len(method_names)):
for j in range(i + 1, len(method_names)):
fields_i = method_field_usage[method_names[i]]
fields_j = method_field_usage[method_names[j]]
if fields_i.intersection(fields_j):
shared_fields += 1
else:
no_shared_fields += 1
total_pairs = no_shared_fields + shared_fields
if total_pairs == 0:
return 0.0
        # LCOM = max(0, P - Q) / total_pairs, where P = pairs sharing no fields
        # and Q = pairs sharing at least one field; clamped to the 0-1 range
lcom = max(0, no_shared_fields - shared_fields) / total_pairs
return round(lcom, 3)
def _analyze_js_coupling(self, content: str, metrics: FileMetrics, file_path: Path):
"""Analyze JavaScript/TypeScript coupling using regex"""
# Extract imports
imports = self._extract_js_imports(content)
metrics.coupling_efferent = len(imports)
self.import_graph[str(file_path)] = imports
# Extract exports
exports = self._extract_js_exports(content)
self.export_graph[str(file_path)] = exports
# Find classes
        class_pattern = r'(?:export\s+(?:default\s+)?)?(?:abstract\s+)?class\s+(\w+)(?:\s+extends\s+([\w.]+))?'
for match in re.finditer(class_pattern, content):
class_name = match.group(1)
base_class = match.group(2)
line_num = content[:match.start()].count('\n') + 1
class_metrics = ClassMetrics(
name=class_name,
line_number=line_num
)
if base_class:
class_metrics.inheritance_depth = 1
class_metrics.coupling_efferent = 1
metrics.classes.append(class_metrics)
    def _extract_js_imports(self, content: str) -> Set[str]:
        """Extract JavaScript/TypeScript imports"""
        imports = set()
        # ES6 imports (named, default, and side-effect forms)
        import_pattern = r'import\s+(?:.*?\s+from\s+)?[\'"]([^\'"]+)[\'"]'
        for match in re.finditer(import_pattern, content):
            imports.add(self._normalize_js_module(match.group(1)))
        # CommonJS requires
        require_pattern = r'require\s*\(\s*[\'"]([^\'"]+)[\'"]\s*\)'
        for match in re.finditer(require_pattern, content):
            imports.add(self._normalize_js_module(match.group(1)))
        return imports
    def _normalize_js_module(self, module: str) -> str:
        """Normalize a module specifier to its top-level package name"""
        if module.startswith('.'):
            # Relative imports are grouped under a single 'local' bucket
            return 'local'
        if module.startswith('@'):
            # Scoped packages keep scope + package name, e.g. '@angular/core'
            return '/'.join(module.split('/')[:2])
        return module.split('/')[0]
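    # Illustrative normalization results (per the helper above): './utils'
    # maps to 'local', '@scope/pkg/sub' to '@scope/pkg', 'lodash/get' to 'lodash'.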
def _extract_js_exports(self, content: str) -> Set[str]:
"""Extract JavaScript/TypeScript exports"""
exports = set()
# Named exports
export_pattern = r'export\s+(?:const|let|var|function|class)\s+(\w+)'
for match in re.finditer(export_pattern, content):
exports.add(match.group(1))
# Export statements
export_list_pattern = r'export\s*\{([^}]+)\}'
for match in re.finditer(export_list_pattern, content):
items = match.group(1).split(',')
for item in items:
# Handle 'name as alias' syntax
parts = item.strip().split(' as ')
exports.add(parts[0].strip())
# Default export
if re.search(r'export\s+default', content):
exports.add('default')
return exports
def _extract_imports_regex(self, content: str) -> Set[str]:
"""Fallback regex-based import extraction"""
imports = set()
        if self.language == "python":
            # Python imports (handles `from pkg import x` and `import a, b as c`)
            import_pattern = r'^\s*(?:from\s+(\S+)\s+import|import\s+(.+))'
            for match in re.finditer(import_pattern, content, re.MULTILINE):
                if match.group(1):
                    module = match.group(1).lstrip('.')
                    if module:
                        imports.add(module.split('.')[0])
                else:
                    for part in match.group(2).split(','):
                        name = part.strip().split(' as ')[0]
                        if name:
                            imports.add(name.split('.')[0])
else:
# JavaScript/TypeScript
imports = self._extract_js_imports(content)
return imports
def calculate_afferent_coupling(self, file_path: str) -> int:
"""
Calculate afferent coupling (how many files depend on this file)
Args:
file_path: Path to the file
Returns:
Number of files that import this file
"""
count = 0
file_exports = self.export_graph.get(file_path, set())
if not file_exports:
return 0
# Check all other files' imports
        for other_file in self.import_graph:
if other_file == file_path:
continue
# Check if any export from this file is imported
if self._imports_from_file(other_file, file_path):
count += 1
return count
def _imports_from_file(self, importer: str, exported: str) -> bool:
"""Check if importer imports from exported file"""
# Simplified check - in real implementation would resolve paths properly
importer_path = Path(importer)
exported_path = Path(exported)
# Check relative imports
        try:
            relative = exported_path.relative_to(importer_path.parent)
            # Drop the suffix so e.g. pkg/utils.py compares as pkg.utils
            relative_import = str(relative.with_suffix('')).replace('/', '.').replace('\\', '.')
            imports = self.import_graph.get(importer, set())
            return relative_import in imports or exported_path.stem in imports
except ValueError:
# Not a relative path
return exported_path.stem in self.import_graph.get(importer, set())
def analyze_project_coupling(self, files: List[Tuple[Path, str]]) -> Dict[str, Any]:
"""
Analyze coupling for entire project
Args:
files: List of (file_path, content) tuples
Returns:
Project-wide coupling metrics
"""
all_metrics = []
# First pass: collect all imports and exports
for file_path, content in files:
metrics = self.analyze_file(file_path, content)
all_metrics.append(metrics)
# Second pass: calculate afferent coupling
for metrics in all_metrics:
metrics.coupling_afferent = self.calculate_afferent_coupling(metrics.file_path)
# Recalculate instability with updated afferent coupling
if metrics.coupling_afferent + metrics.coupling_efferent > 0:
metrics.instability = metrics.coupling_efferent / (
metrics.coupling_afferent + metrics.coupling_efferent
)
# Calculate project-wide metrics
total_coupling = sum(m.coupling_efferent for m in all_metrics)
avg_instability = sum(m.instability for m in all_metrics) / len(all_metrics) if all_metrics else 0
        # Flag files with more than 10 incoming or outgoing dependencies
highly_coupled = [
m for m in all_metrics
if m.coupling_efferent > 10 or m.coupling_afferent > 10
]
return {
'total_files': len(all_metrics),
'total_coupling': total_coupling,
'average_instability': avg_instability,
'highly_coupled_files': len(highly_coupled),
'metrics': all_metrics
}
def calculate_cyclomatic_complexity(self, code: str) -> int:
"""
        Calculate cyclomatic complexity (rough keyword-based estimate; the
        dedicated complexity analyzer provides the full implementation)
Args:
code: Source code
Returns:
Cyclomatic complexity
"""
        # Basic keyword-count implementation for the coupling analyzer
        complexity = 1
        # Branching keywords shared by the supported languages
        for keyword in (r'\bif\b', r'\bfor\b', r'\bwhile\b'):
            complexity += len(re.findall(keyword, code))
        # Exception-handling keyword differs by language
        handler = r'\bexcept\b' if self.language == "python" else r'\bcatch\b'
        complexity += len(re.findall(handler, code))
        return complexity
def calculate_cognitive_complexity(self, code: str) -> int:
"""
Calculate cognitive complexity (simplified)
Args:
code: Source code
Returns:
Cognitive complexity
"""
        # Simplified: reuse the cyclomatic estimate as a proxy for cognitive complexity
return self.calculate_cyclomatic_complexity(code)
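

# Minimal usage sketch (assumes the package providing base_metrics is
# importable; run as a module, e.g. `python -m <package>.coupling_analyzer`,
# so the relative import at the top resolves): scan a directory of Python
# files and print project-level coupling figures.
if __name__ == "__main__":
    import sys
    root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path.cwd()
    analyzer = CouplingAnalyzer(language="python", project_root=root)
    sources = [
        (path, path.read_text(encoding="utf-8", errors="ignore"))
        for path in root.rglob("*.py")
    ]
    summary = analyzer.analyze_project_coupling(sources)
    print(f"Files analyzed:       {summary['total_files']}")
    print(f"Total coupling (Ce):  {summary['total_coupling']}")
    print(f"Average instability:  {summary['average_instability']:.3f}")
    print(f"Highly coupled files: {summary['highly_coupled_files']}")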