test_robustness.py•6.1 kB
"""Test robustness improvements to the MCP server."""
import asyncio
from fedmcp.server import call_tool, validate_limit, sanitize_error_message, logger
async def test_validation():
"""Test input validation."""
print("\n" + "=" * 60)
print("Testing Input Validation")
print("=" * 60)
# Test valid limit
assert validate_limit(10) == 10
print("✓ Valid limit (10) accepted")
# Test default limit
assert validate_limit(None, default=5) == 5
print("✓ Default limit (5) applied when None")
# Test exceeding max
try:
validate_limit(999, max_val=50)
print("✗ Should have rejected limit > max")
return False
except ValueError as e:
print(f"✓ Rejected invalid limit: {e}")
# Test wrong type
try:
validate_limit("not_an_int")
print("✗ Should have rejected non-integer")
return False
except ValueError as e:
print(f"✓ Rejected non-integer: {e}")
return True
async def test_error_sanitization():
"""Test error message sanitization."""
print("\n" + "=" * 60)
print("Testing Error Sanitization")
print("=" * 60)
# Test API key sanitization
error = Exception("Failed: api_key=secret123456")
sanitized = sanitize_error_message(error)
assert "secret123456" not in sanitized
assert "api_key=***" in sanitized
print(f"✓ API key sanitized: {sanitized}")
# Test token sanitization
error = Exception("Error with token: Bearer abc123")
sanitized = sanitize_error_message(error)
assert "abc123" not in sanitized
print(f"✓ Token sanitized: {sanitized}")
return True
async def test_tool_error_handling():
"""Test that tools handle errors gracefully."""
print("\n" + "=" * 60)
print("Testing Tool Error Handling")
print("=" * 60)
# Test list_votes with invalid limit
try:
result = await call_tool("list_votes", {"limit": 999})
response_text = result[0].text
if "Invalid input" in response_text and "between 1 and 100" in response_text:
print(f"✓ list_votes rejects invalid limit")
print(f" Response: {response_text[:100]}...")
else:
print(f"✗ Unexpected response: {response_text}")
return False
except Exception as e:
print(f"✗ Unexpected error: {e}")
return False
# Test get_bill with missing required parameter
try:
result = await call_tool("get_bill", {"code": "c-249"}) # Missing session
response_text = result[0].text
if "Missing required parameter" in response_text or "session" in response_text.lower():
print(f"✓ get_bill rejects missing required parameter")
print(f" Response: {response_text[:100]}...")
else:
print(f"✗ Unexpected response: {response_text}")
return False
except Exception as e:
print(f"✗ Unexpected error: {e}")
return False
# Test unknown tool
try:
result = await call_tool("nonexistent_tool", {})
response_text = result[0].text
if "Unknown tool" in response_text:
print(f"✓ Unknown tool handled gracefully")
print(f" Response: {response_text}")
else:
print(f"✗ Unexpected response: {response_text}")
return False
except Exception as e:
print(f"✗ Unexpected error: {e}")
return False
return True
async def test_valid_request():
"""Test that valid requests still work."""
print("\n" + "=" * 60)
print("Testing Valid Requests")
print("=" * 60)
# Test list_votes with small valid limit
try:
print("Requesting 3 recent votes...")
result = await call_tool("list_votes", {"limit": 3})
response_text = result[0].text
if "Error" not in response_text and "Invalid" not in response_text:
print(f"✓ list_votes works with valid input")
print(f" Response preview: {response_text[:150]}...")
return True
else:
print(f"⚠ Response may indicate an error: {response_text[:200]}")
# This might be an API issue, not our code
return True
except Exception as e:
print(f"⚠ API request failed (may be network/API issue): {e}")
return True # Not a code error
async def main():
"""Run all tests."""
print("\n" + "=" * 60)
print("FedMCP Robustness Test Suite")
print("=" * 60)
tests = [
("Input Validation", test_validation),
("Error Sanitization", test_error_sanitization),
("Tool Error Handling", test_tool_error_handling),
("Valid Requests", test_valid_request),
]
results = []
for test_name, test_func in tests:
try:
result = await test_func()
results.append((test_name, result))
except Exception as e:
print(f"\n✗ {test_name} failed with exception: {e}")
import traceback
traceback.print_exc()
results.append((test_name, False))
# Summary
print("\n" + "=" * 60)
print("Test Summary")
print("=" * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
for test_name, result in results:
status = "✓ PASS" if result else "✗ FAIL"
print(f"{status}: {test_name}")
print("\n" + "=" * 60)
print(f"Results: {passed}/{total} tests passed")
print("=" * 60)
if passed == total:
print("\n🎉 All robustness tests passed!")
print("\nNext step: Restart Claude Desktop to test with the MCP server")
print("Then try these commands:")
print(" - 'list recent parliamentary votes'")
print(" - 'list 500 votes' (should show friendly error)")
print(" - 'search for bills about climate'")
else:
print("\n⚠ Some tests failed. Review the output above.")
return passed == total
if __name__ == "__main__":
success = asyncio.run(main())
exit(0 if success else 1)