#!/usr/bin/env python3
"""
Test LLM Providers
Test OpenAI and Anthropic providers with various scenarios
"""
import asyncio
import sys
import os
from datetime import datetime
# Make the project root (the parent of this file's directory) importable so
# the `src.*` imports below resolve. The path is made absolute and normalised
# so the sys.path entry does not depend on the current working directory.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.insert(0, project_root)
from src.llm_core.providers.factory import ProviderFactory, get_provider, health_check_providers
from src.llm_core.providers.base import LLMMessage, MessageRole
from src.config.settings import settings
import logging
# Module-level logger used by every test below; the root logger is configured
# at INFO so the per-test progress lines are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
async def test_provider_factory() -> dict:
    """Log which providers the factory reports and whether each is available.

    For every entry returned by ``ProviderFactory.get_available_providers()``
    the provider name is logged; available providers also show their default
    model and the number of models they list.

    Returns:
        The mapping returned by the factory, or an empty dict if the lookup
        or the reporting raised an exception (the error is logged).
    """
    logger.info("π Testing Provider Factory...")
    try:
        available = ProviderFactory.get_available_providers()
        logger.info(f"Available providers: {list(available.keys())}")
        for provider_name, info in available.items():
            if info["available"]:
                logger.info(
                    f"✅ {provider_name}: {info['default_model']} "
                    f"(models: {len(info['models'])})"
                )
            else:
                logger.info(f"❌ {provider_name}: Not available")
        return available
    except Exception as e:
        # Best-effort script: report the failure and let the caller see
        # "no providers" rather than aborting the whole run.
        logger.error(f"❌ Provider factory test failed: {e}")
        return {}
async def test_provider_health() -> dict:
    """Run the provider health checks and log one line per provider.

    Returns:
        The mapping produced by ``health_check_providers()`` (provider name
        to a truthy/falsy health flag), or an empty dict if the check raised
        an exception (the error is logged).
    """
    logger.info("π₯ Testing Provider Health Checks...")
    try:
        health_results = await health_check_providers()
        for provider, is_healthy in health_results.items():
            if is_healthy:
                logger.info(f"✅ {provider}: Healthy")
            else:
                logger.info(f"❌ {provider}: Unhealthy")
        return health_results
    except Exception as e:
        # An empty result makes the caller skip every provider.
        logger.error(f"❌ Health check failed: {e}")
        return {}
async def test_simple_chat(provider_name: str) -> bool:
    """Send a single fixed prompt through ``simple_chat`` and log the reply.

    Args:
        provider_name: Name passed to ``get_provider`` to obtain the provider.

    Returns:
        True if the call completed without raising, False otherwise. The
        content of the reply is logged but not validated.
    """
    logger.info(f"π¬ Testing simple chat with {provider_name}...")
    try:
        provider = get_provider(provider_name)
        response = await provider.simple_chat(
            "Say 'Hello from IRIS!' in exactly those words.",
            max_tokens=20,
        )
        logger.info(f"✅ {provider_name} response: {response}")
        return True
    except Exception as e:
        logger.error(f"❌ Simple chat with {provider_name} failed: {e}")
        return False
async def test_conversation_chat(provider_name: str) -> bool:
    """Send a short multi-turn conversation through ``chat`` and log the reply.

    The conversation contains a system message, one completed user/assistant
    exchange and a follow-up user question.

    Args:
        provider_name: Name passed to ``get_provider`` to obtain the provider.

    Returns:
        True if the call completed without raising, False otherwise. The
        response content, usage and model are logged but not validated.
    """
    logger.info(f"π£οΈ Testing conversation with {provider_name}...")
    try:
        provider = get_provider(provider_name)
        messages = [
            LLMMessage(
                role=MessageRole.SYSTEM,
                content="You are a helpful assistant. Keep responses brief.",
            ),
            LLMMessage(role=MessageRole.USER, content="What is 2+2?"),
            LLMMessage(role=MessageRole.ASSISTANT, content="2+2 equals 4."),
            LLMMessage(role=MessageRole.USER, content="What about 3+3?"),
        ]
        response = await provider.chat(messages, max_tokens=20)
        logger.info(f"✅ {provider_name} conversation response: {response.content}")
        logger.info(f" Usage: {response.usage}")
        logger.info(f" Model: {response.model}")
        return True
    except Exception as e:
        logger.error(f"❌ Conversation with {provider_name} failed: {e}")
        return False
async def test_streaming_chat(provider_name: str) -> bool:
    """Stream a reply via ``chat_stream``, echoing chunks as they arrive.

    Args:
        provider_name: Name passed to ``get_provider`` to obtain the provider.

    Returns:
        True if the stream was consumed without raising, False otherwise.
        The assembled text is logged but not validated.
    """
    logger.info(f"π Testing streaming with {provider_name}...")
    try:
        provider = get_provider(provider_name)
        messages = [
            LLMMessage(
                role=MessageRole.USER,
                content="Count from 1 to 5, one number per line.",
            )
        ]
        logger.info(f" {provider_name} streaming response:")
        chunks = []
        async for chunk in provider.chat_stream(messages, max_tokens=50):
            # Flush every chunk so the output appears incrementally.
            print(chunk, end="", flush=True)
            chunks.append(chunk)
        print()  # Terminate the streamed line before logging resumes.
        full_response = "".join(chunks)
        logger.info(
            f"✅ {provider_name} streaming complete. "
            f"Full response: '{full_response.strip()}'"
        )
        return True
    except Exception as e:
        logger.error(f"❌ Streaming with {provider_name} failed: {e}")
        return False
async def test_token_counting(provider_name: str) -> bool:
    """Exercise ``count_tokens`` and ``estimate_tokens`` and log the results.

    Three sample texts of increasing length are counted individually, then a
    two-message list is passed to ``estimate_tokens``.

    Args:
        provider_name: Name passed to ``get_provider`` to obtain the provider.

    Returns:
        True if every call completed without raising, False otherwise. The
        counts are logged but not validated.
    """
    logger.info(f"π’ Testing token counting with {provider_name}...")
    try:
        provider = get_provider(provider_name)
        test_texts = [
            "Hello world!",
            "This is a longer text to test token counting functionality.",
            "The quick brown fox jumps over the lazy dog. " * 10,
        ]
        for text in test_texts:
            token_count = provider.count_tokens(text)
            # Only mark the preview with an ellipsis when text was actually cut.
            preview = text if len(text) <= 50 else f"{text[:50]}..."
            logger.info(f" Text: '{preview}' -> {token_count} tokens")

        messages = [
            LLMMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
            LLMMessage(role=MessageRole.USER, content="What is the capital of France?"),
        ]
        estimated_tokens = provider.estimate_tokens(messages)
        logger.info(f" Message list -> {estimated_tokens} estimated tokens")
        logger.info(f"✅ {provider_name} token counting test complete")
        return True
    except Exception as e:
        logger.error(f"❌ Token counting with {provider_name} failed: {e}")
        return False
async def test_provider(provider_name: str) -> dict:
    """Run every per-provider test against one provider and log a summary.

    The tests run sequentially in a fixed order: simple chat, conversation,
    streaming and token counting.

    Args:
        provider_name: Provider name forwarded to each test function.

    Returns:
        A dict mapping each test's display name to its result; a test that
        raises is recorded as False.
    """
    logger.info(f"\nπ Testing {provider_name.upper()} Provider")
    logger.info("=" * 50)
    tests = [
        ("Simple Chat", test_simple_chat),
        ("Conversation", test_conversation_chat),
        ("Streaming", test_streaming_chat),
        ("Token Counting", test_token_counting),
    ]
    results = {}
    for test_name, test_func in tests:
        try:
            results[test_name] = await test_func(provider_name)
        except Exception as e:
            # Keep going so one failing test cannot hide the others.
            logger.error(f"❌ {test_name} test failed: {e}")
            results[test_name] = False

    passed = sum(1 for result in results.values() if result)
    total = len(results)
    logger.info(f"\nπ {provider_name.upper()} Results: {passed}/{total} tests passed")
    for test_name, result in results.items():
        status = "✅" if result else "❌"
        logger.info(f" {status} {test_name}")
    return results
async def main() -> bool:
    """Run the factory check, the health checks and all per-provider tests.

    Only providers that the factory marks as available and whose health
    check result is truthy are exercised; the rest are logged as skipped.

    Returns:
        True only if at least one provider was exercised and every one of
        its tests passed; False if nothing ran, any test failed, or an
        unexpected exception was raised.
    """
    try:
        logger.info("π Starting LLM Provider Tests...")
        logger.info(f"Default provider: {settings.llm_provider}")

        available_providers = await test_provider_factory()
        health_results = await test_provider_health()

        test_results = {}
        for provider_name, info in available_providers.items():
            if info.get("available") and health_results.get(provider_name):
                test_results[provider_name] = await test_provider(provider_name)
            else:
                logger.info(f"\nβοΈ Skipping {provider_name} (not available or unhealthy)")

        logger.info("\n" + "=" * 60)
        logger.info("π LLM Provider Tests Complete!")
        for provider_name, results in test_results.items():
            passed = sum(1 for result in results.values() if result)
            total = len(results)
            logger.info(f" {provider_name}: {passed}/{total} tests passed")

        # A run counts as successful only when something was tested and
        # nothing failed, so the process exit status reflects test failures.
        all_passed = all(
            all(results.values()) for results in test_results.values()
        )
        return bool(test_results) and all_passed
    except Exception as e:
        logger.error(f"❌ LLM provider tests failed: {e}")
        return False
if __name__ == "__main__":
    # Map the boolean outcome of main() onto a process exit status:
    # 0 when it returns a truthy value, 1 otherwise.
    exit_code = 1
    if asyncio.run(main()):
        exit_code = 0
    sys.exit(exit_code)