test_recursive_resolution.py•8.2 kB
#!/usr/bin/env python3
"""
Test recursive type resolution for MCP Proto Server.
"""
import sys
import json
import logging
from proto_indexer import ProtoIndex
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def print_section(title: str):
"""Print a section header."""
print("\n" + "=" * 80)
print(f" {title}")
print("=" * 80 + "\n")
def test_message_resolution():
"""Test recursive message resolution."""
print_section("TEST 1: Message Recursive Resolution")
index = ProtoIndex()
index.index_directory("examples/")
# Test with CreateUserRequest (references UserRole enum)
print("Getting CreateUserRequest with recursive resolution...")
message = index.get_message("CreateUserRequest", resolve_types=True, max_depth=10)
if message:
print(f"\n✓ Message: {message['full_name']}")
print(f" Fields: {len(message['fields'])}")
if 'resolved_types' in message:
print(f"\n Resolved Types ({len(message['resolved_types'])}):")
for type_name, type_def in message['resolved_types'].items():
print(f" - {type_name} ({type_def['kind']})")
if type_def['kind'] == 'enum':
print(f" Values: {', '.join([v['name'] for v in type_def['values']])}")
elif type_def['kind'] == 'message':
print(f" Fields: {', '.join([f['name'] for f in type_def['fields']])}")
else:
print(" No nested types found")
# Show JSON structure
print("\n Complete JSON Response:")
print(json.dumps(message, indent=2)[:500] + "...")
else:
print("✗ Message not found")
return index
def test_service_resolution(index: ProtoIndex):
"""Test recursive service resolution."""
print_section("TEST 2: Service Recursive Resolution")
print("Getting UserService with recursive resolution...")
service = index.get_service("UserService", resolve_types=True, max_depth=10)
if service:
print(f"\n✓ Service: {service['full_name']}")
print(f" RPCs: {len(service['rpcs'])}")
if 'resolved_types' in service:
print(f"\n Resolved Types ({len(service['resolved_types'])}):")
for type_name, type_def in service['resolved_types'].items():
print(f" - {type_name} ({type_def['kind']})")
if type_def['kind'] == 'message':
print(f" Fields: {len(type_def['fields'])}")
else:
print(" No resolved types")
# Count total definitions returned
total_defs = 1 + len(service.get('resolved_types', {}))
print(f"\n Total definitions in single response: {total_defs}")
print(" ✓ AI agent gets EVERYTHING in ONE call!")
else:
print("✗ Service not found")
def test_depth_control(index: ProtoIndex):
"""Test max_depth parameter."""
print_section("TEST 3: Depth Control")
print("Testing with different max_depth values...")
for depth in [0, 1, 5, 10]:
message = index.get_message("User", resolve_types=True, max_depth=depth)
if message:
resolved_count = len(message.get('resolved_types', {}))
print(f" max_depth={depth}: {resolved_count} types resolved")
def test_performance():
"""Test performance with large proto repo."""
print_section("TEST 4: Performance Test (Optional)")
import time
import os
# Use environment variable if set, otherwise skip this test
large_proto_dir = os.environ.get('LARGE_PROTO_DIR')
if not large_proto_dir:
print("⚠ Skipping performance test (set LARGE_PROTO_DIR env var to enable)")
print(" Example: export LARGE_PROTO_DIR=/path/to/large/proto/repo")
return None
print(f"Indexing proto files from: {large_proto_dir}")
start = time.time()
index = ProtoIndex()
index.index_directory(large_proto_dir)
index_time = time.time() - start
stats = index.get_stats()
print(f"✓ Indexed in {index_time:.2f}s")
print(f" Services: {stats['total_services']}")
print(f" Messages: {stats['total_messages']}")
print(f" Enums: {stats['total_enums']}")
# Test resolution on a random service
print("\nTesting resolution on a service...")
# Get first service
if stats['total_services'] > 0:
service_name = list(index.services.keys())[0]
print(f" Resolving: {service_name}")
start = time.time()
service = index.get_service(service_name, resolve_types=True, max_depth=10)
resolve_time = time.time() - start
if service:
resolved_count = len(service.get('resolved_types', {}))
total_size = len(json.dumps(service))
print(f" ✓ Resolved in {resolve_time*1000:.2f}ms")
print(f" Resolved types: {resolved_count}")
print(f" Response size: {total_size/1024:.1f} KB")
print(f"\n 🚀 Single call vs {resolved_count + 1} round trips!")
print(f" Efficiency gain: {resolved_count}x fewer requests")
return index
def compare_approaches(index: ProtoIndex):
"""Compare old vs new approach."""
print_section("TEST 5: Old vs New Approach Comparison")
service_name = list(index.services.keys())[0] if index.services else None
if not service_name:
print("No services to test")
return
print(f"Testing with: {service_name}\n")
# Old approach (without resolution)
print("OLD APPROACH (multiple round trips):")
service = index.get_service(service_name, resolve_types=False)
if service:
print(f" 1. Get service → {service_name}")
req_resp_types = set()
for rpc in service['rpcs']:
req_resp_types.add(rpc['request_type'])
req_resp_types.add(rpc['response_type'])
print(f" 2. Get {len(req_resp_types)} request/response types")
# Simulate getting all types
nested_count = 0
for type_name in list(req_resp_types)[:3]: # Sample 3
msg = index.get_message(type_name, resolve_types=False)
if msg:
nested_count += len([f for f in msg['fields'] if f['type'] not in ['string', 'int32', 'int64', 'bool']])
print(f" 3. Get ~{nested_count * len(req_resp_types) // 3} nested types")
print(f" TOTAL: ~{1 + len(req_resp_types) + nested_count * len(req_resp_types) // 3} round trips 🐌")
print("\nNEW APPROACH (single call with resolution):")
service = index.get_service(service_name, resolve_types=True, max_depth=10)
if service:
resolved_count = len(service.get('resolved_types', {}))
print(f" 1. Get service with ALL types resolved")
print(f" TOTAL: 1 round trip 🚀")
print(f"\n ✅ {resolved_count}x more efficient!")
def main():
"""Run all tests."""
try:
print("\n" + "=" * 80)
print(" MCP Proto Server - Recursive Resolution Test")
print("=" * 80)
# Test with examples
index = test_message_resolution()
test_service_resolution(index)
test_depth_control(index)
# Test with real files (if configured)
real_index = test_performance()
if real_index:
compare_approaches(real_index)
print_section("✅ All Tests Passed!")
print("The recursive resolution feature is working correctly.")
print("\nKey Benefits:")
print(" ✓ Single API call instead of multiple round trips")
print(" ✓ Complete type information in one response")
print(" ✓ Configurable depth to prevent over-fetching")
print(" ✓ Automatic cycle detection")
print(" ✓ Significant efficiency improvement for AI agents")
except Exception as e:
logger.error(f"Test failed: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()