"""
Test script for Phase 1 implementation
Tests the extracted repository functions
"""
import asyncio
import logging
import os
from sqlalchemy import create_engine
# Module logger for this script. Root logging is configured here, ahead of the
# project imports below, so any records those modules emit while being
# imported are handled at INFO level as well.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Import our new components
from repositories.postgres_repository import PostgresRepository
from repositories.vector_repository import VectorRepository
from repositories.embedding_repository import EmbeddingRepository
from services.schema_service import SchemaService
from services.sql_service import SQLService
from services.semantic_service import SemanticService
from services.synthesis_service import SynthesisService
from presentation.tools.schema_tools import SchemaTools
from presentation.tools.sql_tools import SQLTools
from presentation.tools.semantic_tools import SemanticTools
from presentation.tools.synthesis_tools import SynthesisTools
async def test_postgres_repository():
    """Smoke-test the PostgresRepository helpers against a live database.

    Calls ``get_all_table_names``, ``validate_tables_exist_in_sql`` and
    ``execute_query`` and prints what each returns.

    The connection string is taken from the ``TEST_DATABASE_URL`` environment
    variable. If that is unset, the local default below is used.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    print("\n=== Testing PostgresRepository ===")
    engine = None
    try:
        # Override via TEST_DATABASE_URL instead of editing this file.
        engine = create_engine(
            os.environ.get(
                "TEST_DATABASE_URL",
                "postgresql://postgres:password@localhost:5432/test_db",
            )
        )
        repo = PostgresRepository(engine)

        # Test table listing
        print("Testing get_all_table_names...")
        tables = repo.get_all_table_names()
        print(f"Found {len(tables)} tables: {tables[:5]}...")  # Show first 5

        # Test SQL validation
        print("\nTesting validate_tables_exist_in_sql...")
        test_sql = "SELECT * FROM users"
        error = repo.validate_tables_exist_in_sql(test_sql)
        print(f"Validation result: {error or 'Valid'}")

        # Test safe SQL execution
        print("\nTesting execute_query...")
        result = repo.execute_query("SELECT 1 as test_value")
        print(f"Query result: success={result.success}, data={result.data}")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which a bare print would drop.
        logger.exception("PostgresRepository test failed: %s", e)
        return False
    finally:
        # Release the engine's pooled connections whether or not the test passed.
        if engine is not None:
            engine.dispose()
async def test_vector_repository():
    """Smoke-test the VectorRepository helpers against a live database.

    Calls ``has_vector_extension``, ``text_search_fallback`` and
    ``get_embedding_stats`` and prints what each returns.

    The connection string is taken from the ``TEST_DATABASE_URL`` environment
    variable. If that is unset, the local default below is used.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    print("\n=== Testing VectorRepository ===")
    engine = None
    try:
        # Override via TEST_DATABASE_URL instead of editing this file.
        engine = create_engine(
            os.environ.get(
                "TEST_DATABASE_URL",
                "postgresql://postgres:password@localhost:5432/test_db",
            )
        )
        repo = VectorRepository(engine)

        # Test vector extension check
        print("Testing has_vector_extension...")
        has_vector = repo.has_vector_extension()
        print(f"Has vector extension: {has_vector}")

        # Test text search fallback
        print("\nTesting text_search_fallback...")
        results = repo.text_search_fallback("test query", limit=3)
        print(f"Text search results: {len(results)} results")

        # Test embedding stats
        print("\nTesting get_embedding_stats...")
        stats = repo.get_embedding_stats()
        print(f"Embedding stats: {stats}")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which a bare print would drop.
        logger.exception("VectorRepository test failed: %s", e)
        return False
    finally:
        # Release the engine's pooled connections whether or not the test passed.
        if engine is not None:
            engine.dispose()
async def test_services():
    """Smoke-test the service layer built on top of the repositories.

    Builds SchemaService, SQLService, SemanticService and SynthesisService
    (the last one with an empty ``llm_config``). It then calls one
    representative method on each and prints a short summary of the result.

    The connection string is taken from the ``TEST_DATABASE_URL`` environment
    variable. If that is unset, the local default below is used.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    print("\n=== Testing Services ===")
    engine = None
    try:
        # Override via TEST_DATABASE_URL instead of editing this file.
        engine = create_engine(
            os.environ.get(
                "TEST_DATABASE_URL",
                "postgresql://postgres:password@localhost:5432/test_db",
            )
        )

        # Create repositories
        postgres_repo = PostgresRepository(engine)
        vector_repo = VectorRepository(engine)

        # Create services
        schema_service = SchemaService(postgres_repo)
        sql_service = SQLService(postgres_repo, schema_service)
        semantic_service = SemanticService(vector_repo)
        synthesis_service = SynthesisService(llm_config={})

        # Test schema service
        print("Testing SchemaService...")
        schema_info = schema_service.get_schema_info()
        print(f"Schema info: {len(schema_info.tables)} tables")

        # Test SQL service
        print("\nTesting SQLService...")
        result = sql_service.execute_safe("SELECT 1 as test")
        print(f"SQL service result: {result.success}")

        # Test semantic service
        print("\nTesting SemanticService...")
        search_results = semantic_service.search("test query", limit=3)
        print(f"Semantic search: {len(search_results)} results")

        # Test synthesis service
        print("\nTesting SynthesisService...")
        response = synthesis_service.synthesize_response(
            question="test question",
            sql_results=[{"test": "value"}],
            semantic_results=[],
        )
        print(f"Synthesis response length: {len(response)}")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which a bare print would drop.
        logger.exception("Services test failed: %s", e)
        return False
    finally:
        # Release the engine's pooled connections whether or not the test passed.
        if engine is not None:
            engine.dispose()
async def test_mcp_tools():
    """Smoke-test the MCP tool wrappers end to end.

    Builds the repository -> service -> tool stack and awaits one tool method
    per tool class. Each tool returns a dict, and its ``'success'`` entry is
    printed.

    The connection string is taken from the ``TEST_DATABASE_URL`` environment
    variable. If that is unset, the local default below is used.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    print("\n=== Testing MCP Tools ===")
    engine = None
    try:
        # Override via TEST_DATABASE_URL instead of editing this file.
        engine = create_engine(
            os.environ.get(
                "TEST_DATABASE_URL",
                "postgresql://postgres:password@localhost:5432/test_db",
            )
        )

        # Create repositories
        postgres_repo = PostgresRepository(engine)
        vector_repo = VectorRepository(engine)

        # Create services
        schema_service = SchemaService(postgres_repo)
        sql_service = SQLService(postgres_repo, schema_service)
        semantic_service = SemanticService(vector_repo)
        synthesis_service = SynthesisService(llm_config={})

        # Create MCP tools
        schema_tools = SchemaTools(schema_service)
        sql_tools = SQLTools(sql_service)
        semantic_tools = SemanticTools(semantic_service)
        synthesis_tools = SynthesisTools(synthesis_service)

        # Test schema tools
        print("Testing SchemaTools...")
        result = await schema_tools.get_schema_info()
        print(f"Schema tool result: success={result['success']}")

        # Test SQL tools
        print("\nTesting SQLTools...")
        result = await sql_tools.execute_sql("SELECT 1 as test")
        print(f"SQL tool result: success={result['success']}")

        # Test semantic tools
        print("\nTesting SemanticTools...")
        result = await semantic_tools.semantic_search("test query")
        print(f"Semantic tool result: success={result['success']}")

        # Test synthesis tools
        print("\nTesting SynthesisTools...")
        result = await synthesis_tools.synthesize_response(
            question="test question",
            sql_results=[{"test": "value"}],
        )
        print(f"Synthesis tool result: success={result['success']}")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which a bare print would drop.
        logger.exception("MCP Tools test failed: %s", e)
        return False
    finally:
        # Release the engine's pooled connections whether or not the test passed.
        if engine is not None:
            engine.dispose()
async def main():
    """Run each Phase 1 smoke test in turn and print a pass/fail summary.

    The checks are awaited one after another, never concurrently, in the
    order listed below.

    Returns:
        True when every check reported success, False if any did not.
    """
    print("Starting Phase 1 Implementation Tests")
    print("=" * 50)

    # These checks need a reachable database. Adjust the connection settings,
    # or mock the database, if none is available.
    checks = (
        test_postgres_repository,
        test_vector_repository,
        test_services,
        test_mcp_tools,
    )
    outcomes = [await check() for check in checks]

    rule = "=" * 50
    print("\n" + rule)
    print("TEST SUMMARY")
    print(rule)

    passed, total = sum(outcomes), len(outcomes)
    print(f"Tests passed: {passed}/{total}")

    everything_ok = passed == total
    if not everything_ok:
        print("❌ Some tests failed - check connection strings and database setup")
    else:
        print("✅ All Phase 1 components working correctly!")
    return everything_ok
if __name__ == "__main__":
# Run the tests
success = asyncio.run(main())
exit(0 if success else 1)