test_tools.py•11 kB
#!/usr/bin/env python3
"""Test script for Travel Company MCP Server tools"""
import sys
from pathlib import Path
# Add project to path
sys.path.insert(0, str(Path(__file__).parent))
from src.database import Database, CustomerDB, TripDB, RequestDB
def test_customer_search():
    """Test customer search functionality.

    Exercises ``CustomerDB.search`` on the ``name``, ``email`` and
    ``customer_id`` fields against ``data/travel_company.db`` and prints a
    report as it goes. Every check runs; afterwards an ``AssertionError``
    listing the failed checks is raised, so callers (``run_all_tests`` or
    pytest) see a failure instead of a report that always reads as a pass.
    """
    print("\n" + "="*60)
    print("TEST: Customer Search")
    print("="*60)

    # Resolve the DB next to this script (like the sys.path setup at the top
    # of the file) so the test does not depend on the working directory.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        customer_db = CustomerDB(db.conn)

        # Test 1: Search by name
        print("\n1. Searching for customers named 'John'...")
        results = customer_db.search("John", "name")
        print(f" Found {len(results)} customers")
        if results:
            print("\n First 3 results:")
            for customer in results[:3]:
                print(f" - ID: {customer['customer_id']}, Name: {customer['name']}, Email: {customer['email']}")
            # Verify all results contain "John" (case-insensitive)
            wrong = [c for c in results if "john" not in c['name'].lower()]
            all_have_john = not wrong
            print(f"\n {'✓' if all_have_john else '✗'} All results contain 'John': {all_have_john}")
            if wrong:
                print(" ✗ ERROR: Some results don't contain 'John'!")
                for customer in wrong:
                    print(f" - Wrong result: {customer['name']}")
                failures.append("name search returned customers without 'John'")

        # Test 2: Search by email
        print("\n2. Searching for customers with email containing 'smith'...")
        results = customer_db.search("smith", "email")
        print(f" Found {len(results)} customers")
        if results:
            print(f" First result: {results[0]['email']}")
            all_have_smith = all("smith" in c['email'].lower() for c in results)
            print(f" {'✓' if all_have_smith else '✗'} All results contain 'smith': {all_have_smith}")
            if not all_have_smith:
                failures.append("email search returned customers without 'smith'")

        # Test 3: Search by customer_id
        print("\n3. Searching for customer ID 5...")
        results = customer_db.search("5", "customer_id")
        print(f" Found {len(results)} customer(s)")
        if results:
            print(f" Result: ID={results[0]['customer_id']}, Name={results[0]['name']}")
            is_correct_id = results[0]['customer_id'] == 5
            print(f" {'✓' if is_correct_id else '✗'} Correct ID returned: {is_correct_id}")
            if not is_correct_id:
                failures.append("customer_id search did not return ID 5 first")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_customer_profile():
    """Test customer profile retrieval.

    Checks that ``CustomerDB.get_profile`` returns a profile with a complete
    ``statistics`` mapping for customer 1, and ``None`` for an unknown ID.
    All checks run; an ``AssertionError`` listing the failed ones is raised
    at the end.
    """
    print("\n" + "="*60)
    print("TEST: Customer Profile")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        customer_db = CustomerDB(db.conn)

        print("\n1. Getting profile for customer ID 1...")
        profile = customer_db.get_profile(1)
        if profile:
            print(f" Name: {profile['name']}")
            print(f" Email: {profile['email']}")
            print(f" Loyalty Tier: {profile['loyalty_tier']}")

            # Verify the statistics structure BEFORE reading from it: reading
            # first would raise KeyError, so the check could never report.
            has_stats = 'statistics' in profile
            stats = profile['statistics'] if has_stats else {}
            has_all_fields = has_stats and all(
                key in stats for key in ['total_trips', 'lifetime_spending', 'last_trip_date']
            )
            if has_all_fields:
                print(f" Total Trips: {stats['total_trips']}")
                print(f" Lifetime Spending: ${stats['lifetime_spending']:,.2f}")
                print(f" Last Trip: {stats['last_trip_date']}")

            print(f"\n {'✓' if has_stats else '✗'} Has statistics: {has_stats}")
            print(f" {'✓' if has_all_fields else '✗'} Has all stat fields: {has_all_fields}")
            if not has_all_fields:
                failures.append("profile for customer 1 is missing statistics fields")
        else:
            print(" ✗ ERROR: No profile returned!")
            failures.append("no profile returned for customer 1")

        # Test non-existent customer
        print("\n2. Testing non-existent customer ID 99999...")
        profile = customer_db.get_profile(99999)
        is_none = profile is None
        print(f" {'✓' if is_none else '✗'} Returns None for non-existent: {is_none}")
        if not is_none:
            failures.append("profile returned for non-existent customer 99999")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_trip_search():
    """Test trip search functionality.

    Exercises ``TripDB.search`` by destination substring, by status and by a
    start/end date window. All checks run; an ``AssertionError`` listing the
    failed ones is raised at the end.
    """
    print("\n" + "="*60)
    print("TEST: Trip Search")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        trip_db = TripDB(db.conn)

        # Test 1: Search by destination
        print("\n1. Searching for trips to 'Paris'...")
        results = trip_db.search(destination="Paris")
        print(f" Found {len(results)} trips")
        if results:
            print(" First 3 destinations:")
            for trip in results[:3]:
                print(f" - {trip['destination']} (${trip['cost']:.2f})")
            wrong = [t for t in results if "paris" not in t['destination'].lower()]
            all_have_paris = not wrong
            print(f"\n {'✓' if all_have_paris else '✗'} All results contain 'Paris': {all_have_paris}")
            if wrong:
                print(" ✗ ERROR: Some results don't contain 'Paris'!")
                for trip in wrong:
                    print(f" - Wrong result: {trip['destination']}")
                failures.append("destination search returned non-Paris trips")

        # Test 2: Search by status
        print("\n2. Searching for upcoming trips...")
        results = trip_db.search(status="upcoming", limit=10)
        print(f" Found {len(results)} upcoming trips")
        if results:
            wrong = [t for t in results if t['status'] != 'upcoming']
            all_upcoming = not wrong
            print(f" {'✓' if all_upcoming else '✗'} All results are upcoming: {all_upcoming}")
            if wrong:
                print(" ✗ ERROR: Some results have wrong status!")
                for trip in wrong:
                    print(f" - Wrong status: {trip['status']}")
                failures.append("status search returned non-upcoming trips")

        # Test 3: Search by date range
        print("\n3. Searching for trips in 2024...")
        results = trip_db.search(start_date="2024-01-01", end_date="2024-12-31", limit=10)
        print(f" Found {len(results)} trips in 2024")
        if results:
            # Use min/max rather than first/last row so the printed range is
            # correct whatever order search() returns rows in.
            start_dates = [t['start_date'] for t in results]
            print(f" Date range: {min(start_dates)} to {max(start_dates)}")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_trip_history():
    """Test trip history retrieval.

    Checks that ``TripDB.get_by_customer`` returns only customer 1's trips,
    and that a lookup for customer 999 returns without raising. All checks
    run; an ``AssertionError`` listing the failed ones is raised at the end.
    """
    print("\n" + "="*60)
    print("TEST: Trip History")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        trip_db = TripDB(db.conn)

        print("\n1. Getting trip history for customer ID 1...")
        trips = trip_db.get_by_customer(1)
        print(f" Found {len(trips)} trips")
        if trips:
            print("\n First 3 trips:")
            for trip in trips[:3]:
                print(f" - {trip['destination']} ({trip['start_date']}) - ${trip['cost']:.2f} - {trip['status']}")
            # Verify all trips belong to customer 1
            wrong = [t for t in trips if t['customer_id'] != 1]
            all_correct_customer = not wrong
            print(f"\n {'✓' if all_correct_customer else '✗'} All trips belong to customer 1: {all_correct_customer}")
            if wrong:
                print(" ✗ ERROR: Some trips belong to wrong customer!")
                for trip in wrong:
                    print(f" - Wrong customer: {trip['customer_id']}")
                failures.append("trip history for customer 1 includes other customers' trips")

        # Test customer with no trips
        print("\n2. Testing customer with potentially no trips (ID 999)...")
        trips = trip_db.get_by_customer(999)
        print(f" Found {len(trips)} trips")
        # Reaching this line means the lookup returned without raising.
        print(" ✓ Handles non-existent customer gracefully")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_request_search():
    """Test request search functionality.

    Exercises ``RequestDB.search`` by status, by destination substring and by
    email substring. All checks run; an ``AssertionError`` listing the failed
    ones is raised at the end.
    """
    print("\n" + "="*60)
    print("TEST: Request Search")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        request_db = RequestDB(db.conn)

        # Test 1: Search by status
        print("\n1. Searching for pending requests...")
        results = request_db.search(status="pending")
        print(f" Found {len(results)} pending requests")
        if results:
            print("\n First 3 requests:")
            for req in results[:3]:
                print(f" - {req['name']} ({req['email']}) - {req['destination_interest']}")
            all_pending = all(req['status'] == 'pending' for req in results)
            print(f"\n {'✓' if all_pending else '✗'} All results are pending: {all_pending}")
            if not all_pending:
                print(" ✗ ERROR: Some results have wrong status!")
                failures.append("status search returned non-pending requests")

        # Test 2: Search by destination
        print("\n2. Searching for requests about 'Paris'...")
        results = request_db.search(destination="Paris")
        print(f" Found {len(results)} requests")
        if results:
            all_have_paris = all("paris" in req['destination_interest'].lower() for req in results)
            print(f" {'✓' if all_have_paris else '✗'} All results contain 'Paris': {all_have_paris}")
            if not all_have_paris:
                failures.append("destination search returned non-Paris requests")

        # Test 3: Search by email
        print("\n3. Searching for requests with email containing 'smith'...")
        results = request_db.search(email="smith")
        print(f" Found {len(results)} requests")
        if results:
            print(f" First email: {results[0]['email']}")
            # Same substring check as the customer email search test.
            all_have_smith = all("smith" in req['email'].lower() for req in results)
            print(f" {'✓' if all_have_smith else '✗'} All results contain 'smith': {all_have_smith}")
            if not all_have_smith:
                failures.append("email search returned requests without 'smith'")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_pending_requests():
    """Test pending requests retrieval.

    Calls ``RequestDB.get_pending`` for 7-, 30- and 60-day look-back windows
    and checks that every returned request has status ``pending``. It also
    checks that widening the window never yields fewer requests. All checks
    run; an ``AssertionError`` listing the failed ones is raised at the end.
    """
    print("\n" + "="*60)
    print("TEST: Pending Requests (Last N Days)")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    failures = []
    try:
        request_db = RequestDB(db.conn)

        # Test different time ranges (ascending, so counts should not shrink)
        counts = []
        for days in [7, 30, 60]:
            print(f"\n{days} days back:")
            pending = request_db.get_pending(days_back=days)
            print(f" Found {len(pending)} pending requests")
            counts.append(len(pending))
            if pending:
                # Verify all are pending
                all_pending = all(req['status'] == 'pending' for req in pending)
                print(f" {'✓' if all_pending else '✗'} All are pending: {all_pending}")
                if not all_pending:
                    failures.append(f"non-pending request in {days}-day window")

        # NOTE(review): assumes days_back is a cutoff on request age, so a
        # wider window is a superset of a narrower one; confirm in RequestDB.
        is_monotonic = counts == sorted(counts)
        print(f"\n {'✓' if is_monotonic else '✗'} Wider windows return at least as many: {is_monotonic}")
        if not is_monotonic:
            failures.append(f"pending counts shrink as window widens: {counts}")

        if failures:
            raise AssertionError("; ".join(failures))
    finally:
        # Always release the connection, even when a check or query raises.
        db.close()
def test_data_integrity():
    """Test data integrity and relationships.

    Prints row counts and data-variety figures for the ``customers``,
    ``trips`` and ``requests`` tables. Fails (``AssertionError``) if any trip
    references a ``customer_id`` with no matching customer row.
    """
    print("\n" + "="*60)
    print("TEST: Data Integrity")
    print("="*60)

    # Resolve the DB next to this script so the working directory doesn't matter.
    db = Database(str(Path(__file__).resolve().parent / "data" / "travel_company.db"))
    try:
        cursor = db.conn.cursor()

        # Count records
        cursor.execute("SELECT COUNT(*) FROM customers")
        customer_count = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(*) FROM trips")
        trip_count = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(*) FROM requests")
        request_count = cursor.fetchone()[0]

        print("\nDatabase Statistics:")
        print(f" Customers: {customer_count}")
        print(f" Trips: {trip_count}")
        print(f" Requests: {request_count}")

        # Check for orphaned trips (trips whose customer_id has no customer row)
        cursor.execute("""
            SELECT COUNT(*) FROM trips t
            WHERE NOT EXISTS (
                SELECT 1 FROM customers c WHERE c.customer_id = t.customer_id
            )
        """)
        orphaned_trips = cursor.fetchone()[0]
        no_orphans = orphaned_trips == 0
        print(f"\n {'✓' if no_orphans else '✗'} Orphaned trips: {orphaned_trips} (should be 0)")

        # Check data variety (informational only)
        cursor.execute("SELECT COUNT(DISTINCT destination) FROM trips")
        unique_destinations = cursor.fetchone()[0]
        print(f" Unique destinations: {unique_destinations}")
        cursor.execute("SELECT COUNT(DISTINCT loyalty_tier) FROM customers")
        unique_tiers = cursor.fetchone()[0]
        print(f" Loyalty tiers: {unique_tiers}")

        if not no_orphans:
            raise AssertionError(f"{orphaned_trips} orphaned trip(s) found")
    finally:
        # Always release the connection, even when a query or check raises.
        db.close()
def run_all_tests():
    """Run each tool test in sequence and report the overall outcome.

    The first test that raises stops the run; the exception message and its
    traceback are printed.

    Returns:
        0 when every test function returned normally, 1 otherwise.
    """
    banner = "#" * 60

    print("\n" + banner)
    print("# TRAVEL COMPANY MCP SERVER - TOOL TESTS")
    print(banner)

    try:
        suite = (
            test_customer_search,
            test_customer_profile,
            test_trip_search,
            test_trip_history,
            test_request_search,
            test_pending_requests,
            test_data_integrity,
        )
        for run_test in suite:
            run_test()

        print("\n" + banner)
        print("# ALL TESTS COMPLETED")
        print(banner)
        print("\n✅ Test suite finished successfully!")
    except Exception as exc:
        # Top-level boundary: report any failure and signal it via the status code.
        print(f"\n❌ ERROR: {exc}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # Use the suite's status (0 = all passed, 1 = a test raised) as the exit code.
    raise SystemExit(run_all_tests())