from unittest.mock import patch

import boto3
import pytest
from botocore.exceptions import ClientError, NoCredentialsError
from moto import mock_athena, mock_glue, mock_s3

from src.asl_mcp_server.athena.client import AthenaClient
from src.asl_mcp_server.config.settings import Settings


@pytest.mark.integration
class TestAthenaIntegration:
    """Integration tests for Athena client functionality.

    Most tests patch ``AthenaClient._get_client`` so every Athena API call the
    client makes goes to a ``MagicMock`` whose responses are scripted per test.
    ``test_athena_connection_test`` instead runs the real ``_get_client`` under
    the moto service mocks.
    """

    @pytest.fixture
    def integration_settings(self):
        """Settings for integration testing."""
        return Settings(
            aws_region="us-east-1",
            security_lake_database="test_security_lake_db",
            athena_workgroup="test-workgroup",
            athena_output_location="s3://test-athena-results/",
            log_level="DEBUG",
            max_query_results=100,
            query_timeout_seconds=60,
        )

    @pytest.fixture
    def mock_athena_client(self, integration_settings, mock_aws_services):
        """Create a real ``AthenaClient``; individual tests patch its ``_get_client``.

        ``mock_aws_services`` is requested only for its setup side effects; it is
        defined outside this file (presumably in conftest.py).
        """
        return AthenaClient(integration_settings)

    @mock_athena
    @mock_glue
    @mock_s3
    def test_athena_connection_test(self, integration_settings, aws_credentials):
        """Test that the real ``_get_client`` builds a client under moto.

        Nothing is patched here, so the assertions exercise the client's actual
        construction logic rather than a mock's own return value.
        ``aws_credentials`` is requested only for its side effects (defined
        outside this file; presumably it provides fake AWS credentials).
        """
        client = AthenaClient(integration_settings)

        # Seed the mocked Glue catalog with the database the settings point at,
        # and confirm the moto setup took effect before testing the client.
        glue_client = boto3.client("glue", region_name="us-east-1")
        glue_client.create_database(
            DatabaseInput={
                "Name": "test_security_lake_db",
                "Description": "Test Security Lake database",
            }
        )
        database = glue_client.get_database(Name="test_security_lake_db")
        assert database["Database"]["Name"] == "test_security_lake_db"

        athena_api = client._get_client()

        assert athena_api is not None
        # The remaining tests drive this method on whatever _get_client returns.
        assert callable(getattr(athena_api, "start_query_execution", None))

    @pytest.mark.asyncio
    async def test_query_execution_flow(self, mock_athena_client):
        """Test the complete flow: start query, poll status, fetch results."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value

            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            # First poll reports RUNNING, second reports SUCCEEDED.
            athena_api.get_query_execution.side_effect = [
                {"QueryExecution": {"Status": {"State": "RUNNING"}}},
                {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}},
            ]
            # The first row is the header row; the second is the only data row.
            athena_api.get_query_results.return_value = {
                "ResultSet": {
                    "Rows": [
                        {
                            "Data": [
                                {"VarCharValue": "time"},
                                {"VarCharValue": "src_ip"},
                                {"VarCharValue": "dst_ip"},
                            ]
                        },
                        {
                            "Data": [
                                {"VarCharValue": "2024-01-15T10:30:00Z"},
                                {"VarCharValue": "192.168.1.100"},
                                {"VarCharValue": "203.0.113.45"},
                            ]
                        },
                    ]
                }
            }

            results = await mock_athena_client.execute_query(
                "SELECT time, src_ip, dst_ip FROM test_table LIMIT 10"
            )

            athena_api.start_query_execution.assert_called_once()
            assert len(results) == 1
            assert results[0]["time"] == "2024-01-15T10:30:00Z"
            assert results[0]["src_ip"] == "192.168.1.100"
            assert results[0]["dst_ip"] == "203.0.113.45"

    @pytest.mark.asyncio
    async def test_query_timeout_handling(self, mock_athena_client):
        """Test that a query that never leaves RUNNING raises a timeout error.

        NOTE(review): if the client polls with real sleeps, this test takes
        roughly ``query_timeout_seconds`` of wall time — confirm.
        """
        # Shorten the timeout; the fixture is function-scoped, so this does not leak.
        mock_athena_client.settings.query_timeout_seconds = 1

        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            # Every poll reports RUNNING, so the query never completes.
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "RUNNING"}}
            }

            with pytest.raises(RuntimeError, match="Query timeout"):
                await mock_athena_client.execute_query("SELECT * FROM test_table")

    @pytest.mark.asyncio
    async def test_query_failure_handling(self, mock_athena_client):
        """Test that a FAILED query state surfaces as a RuntimeError."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {
                    "Status": {
                        "State": "FAILED",
                        "StateChangeReason": "Table not found",
                    }
                }
            }

            # Only the "Query failed" prefix of the message is asserted.
            with pytest.raises(RuntimeError, match="Query failed"):
                await mock_athena_client.execute_query("SELECT * FROM nonexistent_table")

    @pytest.mark.asyncio
    async def test_large_result_pagination(self, mock_athena_client):
        """Test that results spanning two pages are fetched and merged."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
            }

            # Page 1 carries the header row plus one data row and a NextToken.
            first_page = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "col1"}, {"VarCharValue": "col2"}]},
                        {"Data": [{"VarCharValue": "row1_col1"}, {"VarCharValue": "row1_col2"}]},
                    ]
                },
                "NextToken": "next-token-123",
            }
            # Page 2 has no header row and no NextToken, which ends pagination.
            second_page = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "row2_col1"}, {"VarCharValue": "row2_col2"}]},
                    ]
                }
            }
            athena_api.get_query_results.side_effect = [first_page, second_page]

            results = await mock_athena_client.execute_query("SELECT * FROM test_table")

            assert len(results) == 2
            assert results[0]["col1"] == "row1_col1"
            assert results[1]["col1"] == "row2_col1"
            # boto3 client methods accept keyword arguments only, so the token
            # for the second page must have been passed as NextToken=...
            assert athena_api.get_query_results.call_count == 2
            second_call = athena_api.get_query_results.call_args_list[1]
            assert second_call.kwargs.get("NextToken") == "next-token-123"

    @pytest.mark.asyncio
    async def test_result_limit_enforcement(self, mock_athena_client):
        """Test that results are truncated to ``max_query_results``."""
        # Lower the limit; the fixture is function-scoped, so this does not leak.
        mock_athena_client.settings.max_query_results = 2

        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
            }
            # Four data rows against a limit of two.
            athena_api.get_query_results.return_value = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "col1"}]},  # Header
                        {"Data": [{"VarCharValue": "row1"}]},
                        {"Data": [{"VarCharValue": "row2"}]},
                        {"Data": [{"VarCharValue": "row3"}]},  # Should be truncated
                        {"Data": [{"VarCharValue": "row4"}]},  # Should be truncated
                    ]
                }
            }

            results = await mock_athena_client.execute_query("SELECT * FROM test_table")

            # Only the first rows, up to the limit, are kept.
            assert len(results) == 2
            assert [row["col1"] for row in results] == ["row1", "row2"]

    @pytest.mark.asyncio
    async def test_list_databases_functionality(self, mock_athena_client):
        """Test that ``list_databases`` returns the names from the API response."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.list_databases.return_value = {
                "DatabaseList": [
                    {"Name": "test_security_lake_db"},
                    {"Name": "default"},
                    {"Name": "other_database"},
                ]
            }

            databases = await mock_athena_client.list_databases()

            assert len(databases) == 3
            assert "test_security_lake_db" in databases
            assert "default" in databases
            assert "other_database" in databases

    @pytest.mark.asyncio
    async def test_list_tables_functionality(self, mock_athena_client):
        """Test that ``list_tables`` maps table metadata to summary dicts."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.list_table_metadata.return_value = {
                "TableMetadataList": [
                    {
                        "Name": "amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0",
                        "TableType": "EXTERNAL_TABLE",
                        "Columns": [
                            {"Name": "time", "Type": "timestamp"},
                            {"Name": "src_ip", "Type": "string"},
                        ],
                        "Parameters": {
                            "location": "s3://aws-security-data-lake-us-east-1-123456789012/"
                        },
                        "LastAccessTime": "2024-01-15T10:00:00Z",
                    }
                ]
            }

            tables = await mock_athena_client.list_tables("test_security_lake_db")

            assert len(tables) == 1
            table = tables[0]
            assert table["name"] == "amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0"
            assert table["type"] == "EXTERNAL_TABLE"
            # The summary reports the column count, not the column list.
            assert table["columns"] == 2
            assert "s3://" in table["location"]

    @pytest.mark.asyncio
    async def test_credentials_error_handling(self, integration_settings):
        """Test that missing AWS credentials surface as a RuntimeError."""
        client = AthenaClient(integration_settings)

        with patch.object(client, "_get_client") as mock_get_client:
            # Fail at client acquisition, before any Athena call is made.
            mock_get_client.side_effect = NoCredentialsError()

            with pytest.raises(RuntimeError, match="AWS credentials not configured"):
                await client.execute_query("SELECT 1")

    @pytest.mark.asyncio
    async def test_aws_service_error_handling(self, mock_athena_client):
        """Test that a botocore ClientError is re-raised with its error code."""
        with patch.object(mock_athena_client, "_get_client") as mock_get_client:
            error_response = {
                "Error": {
                    "Code": "InvalidRequestException",
                    "Message": "Query syntax error",
                }
            }
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.side_effect = ClientError(
                error_response, "StartQueryExecution"
            )

            with pytest.raises(RuntimeError, match="AWS error \\(InvalidRequestException\\)"):
                await mock_athena_client.execute_query("INVALID SQL")

    @pytest.mark.asyncio
    async def test_connection_test_functionality(self, mock_athena_client):
        """Test that ``test_connection`` maps ``list_databases`` outcomes to a bool."""
        # patch.object detects that list_databases is async and substitutes an
        # AsyncMock, so return_value/side_effect apply when it is awaited.
        with patch.object(mock_athena_client, "list_databases") as mock_list_db:
            # Successful connection test
            mock_list_db.return_value = ["test_db"]
            result = await mock_athena_client.test_connection()
            assert result is True

            # Failed connection test: any exception is reported as False.
            mock_list_db.side_effect = Exception("Connection failed")
            result = await mock_athena_client.test_connection()
            assert result is False