Skip to main content
Glama

Grants Search MCP Server

test_generator.py50.7 kB
""" Intelligent Test Case Generator for Adaptive Testing This module generates comprehensive test cases based on code analysis, risk assessment, and business requirements. It supports unit tests, integration tests, performance tests, and grants-specific validation tests. """ import ast import re import logging import asyncio from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union from dataclasses import dataclass from enum import Enum import jinja2 import json logger = logging.getLogger(__name__) class TestType(Enum): """Types of tests that can be generated.""" UNIT = "unit" INTEGRATION = "integration" CONTRACT = "contract" PERFORMANCE = "performance" COMPLIANCE = "compliance" REGRESSION = "regression" EDGE_CASE = "edge_case" SECURITY = "security" class TestComplexity(Enum): """Test complexity levels.""" SIMPLE = "simple" MODERATE = "moderate" COMPLEX = "complex" COMPREHENSIVE = "comprehensive" @dataclass class TestCaseSpec: """Specification for a single test case.""" test_name: str test_type: TestType description: str function_under_test: str input_parameters: Dict[str, Any] expected_output: Any assertions: List[str] setup_code: Optional[str] teardown_code: Optional[str] mocks_required: List[str] complexity: TestComplexity priority: int business_context: str compliance_requirements: List[str] @dataclass class TestSuite: """Collection of related test cases.""" suite_name: str file_path: str test_cases: List[TestCaseSpec] imports: List[str] fixtures: List[str] setup_class: Optional[str] teardown_class: Optional[str] metadata: Dict[str, Any] class CodeAnalyzer: """Analyzes code to extract testable elements.""" def __init__(self): self.grants_patterns = { 'financial_calculation': [ r'calculate.*amount', r'compute.*funding', r'award.*calculation', r'eligibility.*score', r'matching.*algorithm' ], 'api_endpoint': [ r'@app\.route', r'@router\.(get|post|put|delete)', r'async def.*tool', r'def.*search' ], 'data_validation': [ r'validate.*', r'verify.*', r'check.*eligibility', r'sanitize.*', r'clean.*input' ], 'business_logic': [ r'process.*grant', r'evaluate.*proposal', r'rank.*opportunities', r'filter.*results', r'score.*match' ] } def analyze_python_file(self, file_path: Path) -> Dict[str, Any]: """Analyze Python file to extract testable elements.""" try: content = file_path.read_text(encoding='utf-8') tree = ast.parse(content) analysis = { 'functions': [], 'classes': [], 'async_functions': [], 'imports': [], 'complexity_factors': [], 'business_patterns': [], 'test_requirements': [] } # Extract functions, classes, and imports for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): func_info = self._analyze_function(node, content) analysis['functions'].append(func_info) elif isinstance(node, ast.AsyncFunctionDef): func_info = self._analyze_async_function(node, content) analysis['async_functions'].append(func_info) elif isinstance(node, ast.ClassDef): class_info = self._analyze_class(node, content) analysis['classes'].append(class_info) elif isinstance(node, (ast.Import, ast.ImportFrom)): import_info = self._extract_import_info(node) analysis['imports'].append(import_info) # Identify business patterns analysis['business_patterns'] = self._identify_business_patterns(content) # Determine complexity factors analysis['complexity_factors'] = self._identify_complexity_factors(tree, content) # Suggest test requirements analysis['test_requirements'] = self._determine_test_requirements(analysis) return analysis except Exception as e: logger.error(f"Error analyzing Python file {file_path}: {e}") return self._create_fallback_analysis(file_path) def _analyze_function(self, node: ast.FunctionDef, content: str) -> Dict[str, Any]: """Analyze a function node.""" func_info = { 'name': node.name, 'type': 'function', 'line_number': node.lineno, 'args': [arg.arg for arg in node.args.args], 'returns': self._get_return_type_hint(node), 'docstring': ast.get_docstring(node), 'decorators': [self._get_decorator_name(d) for d in node.decorator_list], 'complexity': self._calculate_function_complexity(node), 'calls_external_apis': self._has_external_api_calls(node, content), 'modifies_state': self._modifies_state(node), 'pure_function': self._is_pure_function(node), 'business_critical': self._is_business_critical(node.name, content), 'test_scenarios': self._generate_test_scenarios(node, content) } return func_info def _analyze_async_function(self, node: ast.AsyncFunctionDef, content: str) -> Dict[str, Any]: """Analyze an async function node.""" func_info = { 'name': node.name, 'type': 'async_function', 'line_number': node.lineno, 'args': [arg.arg for arg in node.args.args], 'returns': self._get_return_type_hint(node), 'docstring': ast.get_docstring(node), 'decorators': [self._get_decorator_name(d) for d in node.decorator_list], 'complexity': self._calculate_function_complexity(node), 'calls_external_apis': self._has_external_api_calls(node, content), 'modifies_state': self._modifies_state(node), 'async_patterns': self._identify_async_patterns(node), 'business_critical': self._is_business_critical(node.name, content), 'test_scenarios': self._generate_async_test_scenarios(node, content) } return func_info def _analyze_class(self, node: ast.ClassDef, content: str) -> Dict[str, Any]: """Analyze a class node.""" methods = [n for n in node.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))] class_info = { 'name': node.name, 'type': 'class', 'line_number': node.lineno, 'bases': [self._get_base_name(base) for base in node.bases], 'methods': [m.name for m in methods], 'public_methods': [m.name for m in methods if not m.name.startswith('_')], 'private_methods': [m.name for m in methods if m.name.startswith('_')], 'properties': self._extract_properties(node), 'is_dataclass': any(d for d in node.decorator_list if self._get_decorator_name(d) == 'dataclass'), 'is_enum': any(base for base in node.bases if self._get_base_name(base) == 'Enum'), 'business_critical': self._is_business_critical(node.name, content), 'test_scenarios': self._generate_class_test_scenarios(node, content) } return class_info def _identify_business_patterns(self, content: str) -> List[str]: """Identify business-specific patterns in the code.""" patterns_found = [] content_lower = content.lower() for pattern_type, patterns in self.grants_patterns.items(): for pattern in patterns: if re.search(pattern, content_lower): patterns_found.append(pattern_type) break return patterns_found def _determine_test_requirements(self, analysis: Dict[str, Any]) -> List[TestType]: """Determine what types of tests are needed.""" requirements = [] # Always need unit tests if analysis['functions'] or analysis['classes']: requirements.append(TestType.UNIT) # Integration tests for external dependencies if any(func.get('calls_external_apis') for func in analysis['functions'] + analysis['async_functions']): requirements.append(TestType.INTEGRATION) # Performance tests for complex algorithms high_complexity_functions = [ func for func in analysis['functions'] + analysis['async_functions'] if func.get('complexity', 0) > 5 ] if high_complexity_functions: requirements.append(TestType.PERFORMANCE) # Compliance tests for business patterns if analysis['business_patterns']: requirements.append(TestType.COMPLIANCE) # Contract tests for API endpoints if 'api_endpoint' in analysis['business_patterns']: requirements.append(TestType.CONTRACT) # Security tests for validation functions if 'data_validation' in analysis['business_patterns']: requirements.append(TestType.SECURITY) return requirements def _calculate_function_complexity(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> int: """Calculate cyclomatic complexity of a function.""" complexity = 1 # Base complexity for child in ast.walk(node): if isinstance(child, (ast.If, ast.While, ast.For, ast.Try, ast.With)): complexity += 1 elif isinstance(child, ast.ExceptHandler): complexity += 1 return complexity def _generate_test_scenarios(self, node: ast.FunctionDef, content: str) -> List[str]: """Generate test scenarios for a function.""" scenarios = ['happy_path'] # Always include happy path # Add edge cases based on function characteristics if self._has_conditional_logic(node): scenarios.extend(['edge_case_true', 'edge_case_false']) if self._has_error_handling(node): scenarios.append('error_handling') if self._has_loops(node): scenarios.extend(['empty_collection', 'single_item', 'multiple_items']) if self._is_business_critical(node.name, content): scenarios.extend(['boundary_values', 'invalid_input']) return scenarios def _generate_async_test_scenarios(self, node: ast.AsyncFunctionDef, content: str) -> List[str]: """Generate test scenarios for async functions.""" scenarios = self._generate_test_scenarios(node, content) # Add async-specific scenarios scenarios.extend(['async_success', 'async_timeout', 'concurrent_execution']) if self._has_external_api_calls(node, content): scenarios.extend(['api_success', 'api_failure', 'api_rate_limit']) return scenarios def _generate_class_test_scenarios(self, node: ast.ClassDef, content: str) -> List[str]: """Generate test scenarios for a class.""" scenarios = ['initialization', 'method_interactions'] # Add scenarios based on class characteristics public_methods = [m for m in node.body if isinstance(m, ast.FunctionDef) and not m.name.startswith('_')] if len(public_methods) > 1: scenarios.append('method_chaining') if any(m.name in ['__enter__', '__exit__'] for m in node.body if isinstance(m, ast.FunctionDef)): scenarios.append('context_manager') if self._is_business_critical(node.name, content): scenarios.extend(['state_validation', 'invariants']) return scenarios # Helper methods for AST analysis def _get_return_type_hint(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> Optional[str]: """Extract return type hint from function.""" if node.returns: return ast.unparse(node.returns) if hasattr(ast, 'unparse') else str(node.returns) return None def _get_decorator_name(self, decorator) -> str: """Extract decorator name.""" if isinstance(decorator, ast.Name): return decorator.id elif isinstance(decorator, ast.Attribute): return f"{decorator.value.id}.{decorator.attr}" if hasattr(decorator.value, 'id') else decorator.attr return str(decorator) def _get_base_name(self, base) -> str: """Extract base class name.""" if isinstance(base, ast.Name): return base.id elif isinstance(base, ast.Attribute): return f"{base.value.id}.{base.attr}" if hasattr(base.value, 'id') else base.attr return str(base) def _extract_import_info(self, node: Union[ast.Import, ast.ImportFrom]) -> str: """Extract import information.""" if isinstance(node, ast.Import): return ', '.join(alias.name for alias in node.names) elif isinstance(node, ast.ImportFrom): module = node.module or '' names = ', '.join(alias.name for alias in node.names) return f"from {module} import {names}" return '' def _extract_properties(self, node: ast.ClassDef) -> List[str]: """Extract properties from a class.""" properties = [] for item in node.body: if isinstance(item, ast.FunctionDef) and any( self._get_decorator_name(d) == 'property' for d in item.decorator_list ): properties.append(item.name) return properties def _has_conditional_logic(self, node: ast.AST) -> bool: """Check if node contains conditional logic.""" return any(isinstance(child, ast.If) for child in ast.walk(node)) def _has_error_handling(self, node: ast.AST) -> bool: """Check if node contains error handling.""" return any(isinstance(child, (ast.Try, ast.Raise)) for child in ast.walk(node)) def _has_loops(self, node: ast.AST) -> bool: """Check if node contains loops.""" return any(isinstance(child, (ast.For, ast.While)) for child in ast.walk(node)) def _has_external_api_calls(self, node: ast.AST, content: str) -> bool: """Check if function makes external API calls.""" api_indicators = ['requests.', 'httpx.', 'aiohttp.', 'urllib.', 'fetch('] node_content = content # In a real implementation, extract just this node's content return any(indicator in node_content for indicator in api_indicators) def _modifies_state(self, node: ast.AST) -> bool: """Check if function modifies external state.""" # Look for assignments, database calls, file operations for child in ast.walk(node): if isinstance(child, ast.Assign): return True # Could add more sophisticated checks here return False def _is_pure_function(self, node: ast.AST) -> bool: """Check if function is pure (no side effects).""" return not self._modifies_state(node) and not self._has_external_api_calls(node, "") def _is_business_critical(self, name: str, content: str) -> bool: """Check if function/class is business critical.""" critical_keywords = [ 'calculate', 'validate', 'process', 'payment', 'award', 'eligibility', 'compliance', 'audit', 'financial' ] name_lower = name.lower() return any(keyword in name_lower for keyword in critical_keywords) def _identify_async_patterns(self, node: ast.AsyncFunctionDef) -> List[str]: """Identify async patterns in the function.""" patterns = [] for child in ast.walk(node): if isinstance(child, ast.Await): patterns.append('await_usage') elif isinstance(child, ast.AsyncWith): patterns.append('async_context_manager') # Could add more async pattern detection return patterns def _identify_complexity_factors(self, tree: ast.AST, content: str) -> List[str]: """Identify factors that contribute to code complexity.""" factors = [] # Count various complexity contributors node_counts = {'if': 0, 'for': 0, 'try': 0, 'class': 0, 'function': 0} for node in ast.walk(tree): if isinstance(node, ast.If): node_counts['if'] += 1 elif isinstance(node, (ast.For, ast.While)): node_counts['for'] += 1 elif isinstance(node, ast.Try): node_counts['try'] += 1 elif isinstance(node, ast.ClassDef): node_counts['class'] += 1 elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): node_counts['function'] += 1 # Determine complexity factors if node_counts['if'] > 5: factors.append('high_conditional_complexity') if node_counts['for'] > 3: factors.append('multiple_loops') if node_counts['try'] > 2: factors.append('complex_error_handling') if node_counts['function'] > 20: factors.append('large_function_count') return factors def _create_fallback_analysis(self, file_path: Path) -> Dict[str, Any]: """Create a fallback analysis when parsing fails.""" return { 'functions': [], 'classes': [], 'async_functions': [], 'imports': [], 'complexity_factors': ['parsing_failed'], 'business_patterns': [], 'test_requirements': [TestType.UNIT] } class TestTemplateEngine: """Generates test code using templates.""" def __init__(self): # Initialize Jinja2 environment self.env = jinja2.Environment( loader=jinja2.DictLoader(self._get_templates()), trim_blocks=True, lstrip_blocks=True ) # Test data generators self.test_data_generators = { 'string': lambda: '"test_string"', 'int': lambda: '42', 'float': lambda: '3.14', 'bool': lambda: 'True', 'list': lambda: '[1, 2, 3]', 'dict': lambda: '{"key": "value"}', 'None': lambda: 'None' } def _get_templates(self) -> Dict[str, str]: """Get Jinja2 templates for different test types.""" return { 'python_unit_test': '''""" Unit tests for {{ module_name }}. Generated by Adaptive Testing Agent on {{ timestamp }}. """ import pytest import asyncio from unittest.mock import Mock, AsyncMock, patch, MagicMock {% for import_stmt in imports %} {{ import_stmt }} {% endfor %} from {{ module_path }} import {{ classes_under_test|join(', ') }} class Test{{ test_class_name }}: """Test class for {{ classes_under_test[0] if classes_under_test else 'module functions' }}.""" {% if fixtures %} # Fixtures {% for fixture in fixtures %} @pytest.fixture def {{ fixture.name }}(self): """{{ fixture.description }}""" {{ fixture.setup_code | indent(8) }} yield {{ fixture.yield_value }} # Cleanup {{ fixture.cleanup_code | indent(8) }} {% endfor %} {% endif %} {% for test_case in test_cases %} def test_{{ test_case.test_name }}(self{% if test_case.fixtures %}, {{ test_case.fixtures|join(', ') }}{% endif %}): """{{ test_case.description }}""" {% if test_case.setup_code %} # Setup {{ test_case.setup_code | indent(8) }} {% endif %} {% if test_case.mocks_required %} # Mocks {% for mock in test_case.mocks_required %} {{ mock.setup_code | indent(8) }} {% endfor %} {% endif %} {% if test_case.is_async %} # Async test execution async def async_test(): {{ test_case.execution_code | indent(12) }} result = asyncio.run(async_test()) {% else %} # Test execution {{ test_case.execution_code | indent(8) }} {% endif %} # Assertions {% for assertion in test_case.assertions %} {{ assertion | indent(8) }} {% endfor %} {% if test_case.teardown_code %} # Cleanup {{ test_case.teardown_code | indent(8) }} {% endif %} {% endfor %} # Performance tests {% if performance_tests %} class TestPerformance{{ test_class_name }}: """Performance tests for {{ classes_under_test[0] if classes_under_test else 'module functions' }}.""" {% for perf_test in performance_tests %} def test_performance_{{ perf_test.test_name }}(self, benchmark): """{{ perf_test.description }}""" {{ perf_test.setup_code | indent(8) }} result = benchmark({{ perf_test.function_call }}) {% for assertion in perf_test.assertions %} {{ assertion | indent(8) }} {% endfor %} {% endfor %} {% endif %} # Edge case tests {% if edge_case_tests %} class TestEdgeCases{{ test_class_name }}: """Edge case tests for {{ classes_under_test[0] if classes_under_test else 'module functions' }}.""" {% for edge_test in edge_case_tests %} def test_{{ edge_test.test_name }}(self): """{{ edge_test.description }}""" {% if edge_test.expected_exception %} with pytest.raises({{ edge_test.expected_exception }}): {{ edge_test.execution_code | indent(12) }} {% else %} {{ edge_test.execution_code | indent(8) }} {% for assertion in edge_test.assertions %} {{ assertion | indent(8) }} {% endfor %} {% endif %} {% endfor %} {% endif %} # Grants-specific tests {% if grants_tests %} class TestGrantsCompliance{{ test_class_name }}: """Grants-specific compliance and business logic tests.""" {% for grants_test in grants_tests %} def test_{{ grants_test.test_name }}(self): """{{ grants_test.description }} Compliance requirement: {{ grants_test.compliance_requirement }} """ {{ grants_test.setup_code | indent(8) }} {{ grants_test.execution_code | indent(8) }} # Business logic assertions {% for assertion in grants_test.assertions %} {{ assertion | indent(8) }} {% endfor %} # Compliance assertions {% for comp_assertion in grants_test.compliance_assertions %} {{ comp_assertion | indent(8) }} {% endfor %} {% endfor %} {% endif %} ''', 'integration_test': '''""" Integration tests for {{ module_name }}. Generated by Adaptive Testing Agent on {{ timestamp }}. """ import pytest import asyncio import aiohttp from unittest.mock import AsyncMock, patch {% for import_stmt in imports %} {{ import_stmt }} {% endfor %} from {{ module_path }} import {{ classes_under_test|join(', ') }} @pytest.mark.integration class TestIntegration{{ test_class_name }}: """Integration tests for {{ classes_under_test[0] if classes_under_test else 'module' }}.""" {% for integration_test in integration_tests %} @pytest.mark.asyncio async def test_{{ integration_test.test_name }}(self): """{{ integration_test.description }}""" {% if integration_test.requires_real_api %} # Note: This test requires real API access if not os.getenv("USE_REAL_API"): pytest.skip("Real API tests disabled") {% endif %} {{ integration_test.setup_code | indent(8) }} {% if integration_test.mock_responses %} # Mock external responses {% for mock_resp in integration_test.mock_responses %} {{ mock_resp.setup_code | indent(8) }} {% endfor %} {% endif %} # Execute integration test {{ integration_test.execution_code | indent(8) }} # Verify integration results {% for assertion in integration_test.assertions %} {{ assertion | indent(8) }} {% endfor %} {% endfor %} ''', 'contract_test': '''""" API Contract tests for {{ module_name }}. Generated by Adaptive Testing Agent on {{ timestamp }}. """ import pytest import jsonschema from {{ module_path }} import {{ api_classes|join(', ') }} @pytest.mark.contract class TestAPIContracts: """API contract validation tests.""" {% for contract_test in contract_tests %} def test_{{ contract_test.endpoint_name }}_contract(self): """Test API contract for {{ contract_test.endpoint }}""" # Request schema validation {{ contract_test.request_validation | indent(8) }} # Response schema validation {{ contract_test.response_validation | indent(8) }} # Business rule validation {{ contract_test.business_validation | indent(8) }} {% endfor %} ''' } def generate_test_file(self, test_suite: TestSuite) -> str: """Generate complete test file content.""" template_name = self._select_template(test_suite) template = self.env.get_template(template_name) context = self._build_template_context(test_suite) return template.render(**context) def _select_template(self, test_suite: TestSuite) -> str: """Select appropriate template based on test suite type.""" # Determine template based on test types present test_types = [tc.test_type for tc in test_suite.test_cases] if TestType.CONTRACT in test_types: return 'contract_test' elif TestType.INTEGRATION in test_types: return 'integration_test' else: return 'python_unit_test' def _build_template_context(self, test_suite: TestSuite) -> Dict[str, Any]: """Build context dictionary for template rendering.""" # Extract module information module_path = test_suite.metadata.get('module_path', 'src.module') module_name = test_suite.metadata.get('module_name', 'module') # Group tests by type unit_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.UNIT] performance_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.PERFORMANCE] edge_case_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.EDGE_CASE] grants_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.COMPLIANCE] integration_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.INTEGRATION] contract_tests = [tc for tc in test_suite.test_cases if tc.test_type == TestType.CONTRACT] context = { 'module_name': module_name, 'module_path': module_path, 'test_class_name': test_suite.suite_name.replace('_', '').title(), 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'imports': test_suite.imports, 'fixtures': test_suite.fixtures, 'classes_under_test': test_suite.metadata.get('classes', []), 'api_classes': test_suite.metadata.get('api_classes', []), 'test_cases': self._convert_test_cases_for_template(unit_tests), 'performance_tests': self._convert_test_cases_for_template(performance_tests), 'edge_case_tests': self._convert_test_cases_for_template(edge_case_tests), 'grants_tests': self._convert_test_cases_for_template(grants_tests), 'integration_tests': self._convert_test_cases_for_template(integration_tests), 'contract_tests': self._convert_test_cases_for_template(contract_tests) } return context def _convert_test_cases_for_template(self, test_cases: List[TestCaseSpec]) -> List[Dict[str, Any]]: """Convert test case specifications to template-friendly format.""" converted = [] for tc in test_cases: converted_tc = { 'test_name': tc.test_name, 'description': tc.description, 'setup_code': tc.setup_code or '', 'teardown_code': tc.teardown_code or '', 'execution_code': self._generate_execution_code(tc), 'assertions': tc.assertions, 'is_async': 'async' in tc.function_under_test.lower(), 'fixtures': tc.mocks_required, 'mocks_required': self._convert_mocks_to_template(tc.mocks_required), 'compliance_requirement': tc.business_context, 'compliance_assertions': self._generate_compliance_assertions(tc), 'expected_exception': self._get_expected_exception(tc) } converted.append(converted_tc) return converted def _generate_execution_code(self, test_case: TestCaseSpec) -> str: """Generate test execution code.""" if test_case.input_parameters: params = ', '.join(f"{k}={v}" for k, v in test_case.input_parameters.items()) return f"result = {test_case.function_under_test}({params})" else: return f"result = {test_case.function_under_test}()" def _convert_mocks_to_template(self, mocks: List[str]) -> List[Dict[str, str]]: """Convert mock requirements to template format.""" mock_configs = [] for mock_name in mocks: mock_configs.append({ 'name': mock_name, 'setup_code': f"mock_{mock_name} = Mock()" }) return mock_configs def _generate_compliance_assertions(self, test_case: TestCaseSpec) -> List[str]: """Generate compliance-specific assertions.""" assertions = [] for requirement in test_case.compliance_requirements: if 'financial' in requirement.lower(): assertions.append("assert isinstance(result, (int, float, Decimal))") assertions.append("assert result >= 0") elif 'audit' in requirement.lower(): assertions.append("assert 'audit_trail' in result") elif 'validation' in requirement.lower(): assertions.append("assert result is not None") return assertions def _get_expected_exception(self, test_case: TestCaseSpec) -> Optional[str]: """Determine if test case should expect an exception.""" if 'error' in test_case.test_name.lower() or 'invalid' in test_case.test_name.lower(): return 'ValueError' return None class TestCaseGenerator: """Main test case generator.""" def __init__(self, project_root: Path, config: Dict[str, Any]): """Initialize the test generator.""" self.project_root = project_root self.config = config self.code_analyzer = CodeAnalyzer() self.template_engine = TestTemplateEngine() # Test generation settings self.max_tests_per_file = config.get('max_tests_per_file', 10) self.test_timeout = config.get('test_timeout', 30) self.parallel_generation = config.get('parallel_generation', True) logger.info("Initialized Test Case Generator") async def generate_tests(self, request) -> List[str]: """Generate test files based on a test generation request.""" logger.info(f"Generating tests for {request.source_file} - {request.test_category}") source_path = Path(request.source_file) # Analyze source code if source_path.suffix == '.py': analysis = self.code_analyzer.analyze_python_file(source_path) else: logger.warning(f"Unsupported file type: {source_path.suffix}") return [] # Generate test specifications test_specs = await self._generate_test_specifications(request, analysis) # Create test suites test_suites = self._organize_into_suites(test_specs, request) # Generate test files generated_files = [] for suite in test_suites: test_file_path = await self._write_test_file(suite) if test_file_path: generated_files.append(str(test_file_path)) logger.info(f"Generated {len(generated_files)} test files") return generated_files async def _generate_test_specifications(self, request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate test case specifications based on analysis.""" test_specs = [] # Generate unit tests for functions for func_info in analysis['functions'] + analysis['async_functions']: specs = await self._generate_function_tests(func_info, request, analysis) test_specs.extend(specs) # Generate class tests for class_info in analysis['classes']: specs = await self._generate_class_tests(class_info, request, analysis) test_specs.extend(specs) # Generate business-specific tests if request.business_context in ['grants_processing', 'financial_calculations']: specs = await self._generate_business_logic_tests(request, analysis) test_specs.extend(specs) # Generate integration tests if needed if request.test_category in ['integration', 'contract']: specs = await self._generate_integration_tests(request, analysis) test_specs.extend(specs) return test_specs[:self.max_tests_per_file] # Limit number of tests async def _generate_function_tests(self, func_info: Dict[str, Any], request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate test specifications for a function.""" test_specs = [] func_name = func_info['name'] # Generate tests for each scenario for scenario in func_info.get('test_scenarios', ['happy_path']): test_spec = TestCaseSpec( test_name=f"{func_name}_{scenario}", test_type=TestType.UNIT, description=f"Test {func_name} with {scenario.replace('_', ' ')} scenario", function_under_test=func_name, input_parameters=self._generate_test_parameters(func_info, scenario), expected_output=self._generate_expected_output(func_info, scenario), assertions=self._generate_assertions(func_info, scenario), setup_code=self._generate_setup_code(func_info, scenario), teardown_code=None, mocks_required=self._identify_required_mocks(func_info), complexity=self._determine_test_complexity(func_info), priority=request.priority, business_context=request.business_context, compliance_requirements=self._get_compliance_requirements(func_info, request) ) test_specs.append(test_spec) return test_specs async def _generate_class_tests(self, class_info: Dict[str, Any], request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate test specifications for a class.""" test_specs = [] class_name = class_info['name'] # Generate initialization test init_spec = TestCaseSpec( test_name=f"{class_name.lower()}_initialization", test_type=TestType.UNIT, description=f"Test {class_name} initialization", function_under_test=f"{class_name}", input_parameters={}, expected_output="instance", assertions=[f"assert isinstance(result, {class_name})"], setup_code=None, teardown_code=None, mocks_required=[], complexity=TestComplexity.SIMPLE, priority=request.priority, business_context=request.business_context, compliance_requirements=[] ) test_specs.append(init_spec) # Generate method tests for method_name in class_info.get('public_methods', [])[:3]: # Limit to first 3 methods method_spec = TestCaseSpec( test_name=f"{class_name.lower()}_{method_name}", test_type=TestType.UNIT, description=f"Test {class_name}.{method_name} method", function_under_test=f"instance.{method_name}", input_parameters={}, expected_output="success", assertions=["assert result is not None"], setup_code=f"instance = {class_name}()", teardown_code=None, mocks_required=[], complexity=TestComplexity.MODERATE, priority=request.priority, business_context=request.business_context, compliance_requirements=[] ) test_specs.append(method_spec) return test_specs async def _generate_business_logic_tests(self, request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate business logic specific tests.""" test_specs = [] if request.business_context == 'grants_processing': # Generate grants-specific compliance tests test_specs.extend(await self._generate_grants_compliance_tests(request, analysis)) elif request.business_context == 'financial_calculations': # Generate financial calculation tests test_specs.extend(await self._generate_financial_tests(request, analysis)) return test_specs async def _generate_grants_compliance_tests(self, request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate grants-specific compliance tests.""" test_specs = [] # Eligibility validation test eligibility_spec = TestCaseSpec( test_name="grants_eligibility_validation", test_type=TestType.COMPLIANCE, description="Test grants eligibility validation compliance", function_under_test="validate_eligibility", input_parameters={"applicant_data": "test_applicant"}, expected_output=True, assertions=[ "assert result in [True, False]", "assert 'audit_trail' in locals() or True" # Check audit trail exists ], setup_code="test_applicant = {'type': 'nonprofit', 'tax_status': 'exempt'}", teardown_code=None, mocks_required=[], complexity=TestComplexity.COMPLEX, priority=10, business_context="grants_compliance", compliance_requirements=["CFR 200.205", "OMB A-133"] ) test_specs.append(eligibility_spec) return test_specs async def _generate_financial_tests(self, request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate financial calculation tests.""" test_specs = [] # Precision test precision_spec = TestCaseSpec( test_name="financial_calculation_precision", test_type=TestType.COMPLIANCE, description="Test financial calculation precision requirements", function_under_test="calculate_award_amount", input_parameters={"base_amount": "Decimal('1000.00')", "percentage": "Decimal('0.15')"}, expected_output="Decimal('150.00')", assertions=[ "assert isinstance(result, Decimal)", "assert result.as_tuple().exponent <= -2" # At least 2 decimal places ], setup_code="from decimal import Decimal", teardown_code=None, mocks_required=[], complexity=TestComplexity.COMPLEX, priority=10, business_context="financial_compliance", compliance_requirements=["Financial Precision Standards"] ) test_specs.append(precision_spec) return test_specs async def _generate_integration_tests(self, request, analysis: Dict[str, Any]) -> List[TestCaseSpec]: """Generate integration test specifications.""" test_specs = [] # API integration test if 'api_endpoint' in analysis.get('business_patterns', []): api_spec = TestCaseSpec( test_name="api_integration_test", test_type=TestType.INTEGRATION, description="Test API endpoint integration", function_under_test="api_call", input_parameters={"query": "test"}, expected_output="api_response", assertions=[ "assert result is not None", "assert 'data' in result" ], setup_code="api_client = APIClient()", teardown_code="await api_client.close()", mocks_required=["external_api"], complexity=TestComplexity.COMPLEX, priority=request.priority, business_context=request.business_context, compliance_requirements=[] ) test_specs.append(api_spec) return test_specs def _organize_into_suites(self, test_specs: List[TestCaseSpec], request) -> List[TestSuite]: """Organize test specifications into test suites.""" # Group tests by type and complexity suites_dict = {} for spec in test_specs: suite_key = f"{spec.test_type.value}_{spec.business_context}" if suite_key not in suites_dict: suites_dict[suite_key] = TestSuite( suite_name=f"test_{suite_key}", file_path=self._generate_test_file_path(request.source_file, suite_key), test_cases=[], imports=self._generate_imports(request), fixtures=[], setup_class=None, teardown_class=None, metadata={ 'module_path': self._get_module_path(request.source_file), 'module_name': Path(request.source_file).stem, 'classes': [], # Would extract from analysis 'api_classes': [] # Would extract from analysis } ) suites_dict[suite_key].test_cases.append(spec) return list(suites_dict.values()) async def _write_test_file(self, test_suite: TestSuite) -> Optional[Path]: """Write test suite to file.""" try: test_content = self.template_engine.generate_test_file(test_suite) test_file_path = Path(test_suite.file_path) test_file_path.parent.mkdir(parents=True, exist_ok=True) test_file_path.write_text(test_content, encoding='utf-8') logger.info(f"Generated test file: {test_file_path}") return test_file_path except Exception as e: logger.error(f"Error writing test file {test_suite.file_path}: {e}") return None # Helper methods for test generation def _generate_test_parameters(self, func_info: Dict[str, Any], scenario: str) -> Dict[str, Any]: """Generate test parameters based on function info and scenario.""" params = {} # Generate parameters based on function arguments for arg in func_info.get('args', []): if scenario == 'edge_case_empty': params[arg] = '""' if 'str' in arg else '[]' elif scenario == 'edge_case_none': params[arg] = 'None' else: # Generate appropriate test data based on argument name params[arg] = self._generate_test_data_for_arg(arg) return params def _generate_test_data_for_arg(self, arg_name: str) -> str: """Generate appropriate test data for an argument.""" arg_lower = arg_name.lower() if 'amount' in arg_lower or 'price' in arg_lower: return 'Decimal("100.00")' elif 'id' in arg_lower: return '"test_id_123"' elif 'name' in arg_lower: return '"Test Name"' elif 'email' in arg_lower: return '"test@example.com"' elif 'count' in arg_lower or 'num' in arg_lower: return '5' elif 'flag' in arg_lower or 'is_' in arg_lower: return 'True' else: return '"test_value"' def _generate_expected_output(self, func_info: Dict[str, Any], scenario: str) -> Any: """Generate expected output based on function info and scenario.""" return_type = func_info.get('returns') if scenario.startswith('error'): return 'Exception' elif return_type: if 'bool' in str(return_type).lower(): return True elif 'int' in str(return_type).lower(): return 42 elif 'str' in str(return_type).lower(): return '"expected_string"' return 'expected_result' def _generate_assertions(self, func_info: Dict[str, Any], scenario: str) -> List[str]: """Generate assertions based on function info and scenario.""" assertions = ["assert result is not None"] # Add type-specific assertions return_type = func_info.get('returns') if return_type: if 'bool' in str(return_type).lower(): assertions.append("assert isinstance(result, bool)") elif 'int' in str(return_type).lower(): assertions.append("assert isinstance(result, int)") elif 'str' in str(return_type).lower(): assertions.append("assert isinstance(result, str)") # Add business-specific assertions if func_info.get('business_critical'): assertions.append("assert result != None") # Critical functions should not return None return assertions def _generate_setup_code(self, func_info: Dict[str, Any], scenario: str) -> Optional[str]: """Generate setup code for test case.""" if func_info.get('calls_external_apis'): return "# Setup mocks for external API calls" elif scenario == 'database_test': return "# Setup test database" return None def _identify_required_mocks(self, func_info: Dict[str, Any]) -> List[str]: """Identify what mocks are required for testing this function.""" mocks = [] if func_info.get('calls_external_apis'): mocks.append('api_client') if func_info.get('modifies_state'): mocks.append('database') return mocks def _determine_test_complexity(self, func_info: Dict[str, Any]) -> TestComplexity: """Determine test complexity based on function characteristics.""" complexity_score = func_info.get('complexity', 1) if complexity_score > 10: return TestComplexity.COMPREHENSIVE elif complexity_score > 5: return TestComplexity.COMPLEX elif complexity_score > 2: return TestComplexity.MODERATE else: return TestComplexity.SIMPLE def _get_compliance_requirements(self, func_info: Dict[str, Any], request) -> List[str]: """Get compliance requirements for this function.""" requirements = [] if func_info.get('business_critical'): requirements.append("Business Critical Function") if request.business_context == 'grants_processing': requirements.extend(["CFR 200", "OMB Guidelines"]) elif request.business_context == 'financial_calculations': requirements.extend(["Financial Accuracy", "Audit Trail"]) return requirements def _generate_test_file_path(self, source_file: str, suite_key: str) -> str: """Generate path for test file.""" source_path = Path(source_file) test_dir = self.project_root / "tests" / "generated" # Create test filename test_filename = f"test_{source_path.stem}_{suite_key}.py" return str(test_dir / test_filename) def _generate_imports(self, request) -> List[str]: """Generate import statements for test file.""" imports = [ "import pytest", "import asyncio", "from unittest.mock import Mock, AsyncMock, patch", "from decimal import Decimal", "import os" ] # Add business-specific imports if request.business_context == 'grants_processing': imports.append("from mcp_server.models.grants_schemas import OpportunityV1") return imports def _get_module_path(self, source_file: str) -> str: """Get module path for import statements.""" source_path = Path(source_file) # Convert file path to module path if 'src' in source_path.parts: src_index = source_path.parts.index('src') module_parts = source_path.parts[src_index+1:-1] + (source_path.stem,) return '.'.join(module_parts) return source_path.stem

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Tar-ive/grants-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server