Skip to main content
Glama
kebabmane

Amazon Security Lake MCP Server

by kebabmane
test_athena_integration.py (13.8 kB)
import pytest
from unittest.mock import patch
import boto3
from moto import mock_athena, mock_glue, mock_s3
from src.asl_mcp_server.athena.client import AthenaClient
from src.asl_mcp_server.config.settings import Settings


@pytest.mark.integration
class TestAthenaIntegration:
    """Integration tests for Athena client functionality.

    Most tests replace ``AthenaClient._get_client`` with a mock and script the
    responses of the underlying Athena API calls, then assert on what the
    public ``AthenaClient`` coroutine returns or raises.

    The local name ``athena_api`` is used for the mocked low-level client so
    that it does not shadow the ``mock_athena`` decorator imported from moto.
    """

    @pytest.fixture
    def integration_settings(self):
        """Settings for integration testing.

        Targets ``us-east-1`` with a test database, workgroup and output
        location, DEBUG logging, a 100-row result cap and a 60-second timeout.
        """
        return Settings(
            aws_region="us-east-1",
            security_lake_database="test_security_lake_db",
            athena_workgroup="test-workgroup",
            athena_output_location="s3://test-athena-results/",
            log_level="DEBUG",
            max_query_results=100,
            query_timeout_seconds=60,
        )

    @pytest.fixture
    def mock_athena_client(self, integration_settings, mock_aws_services):
        """Create Athena client with mocked services.

        ``mock_aws_services`` is requested only for its setup side effects; it
        is not defined in this module (presumably in conftest -- confirm).
        """
        return AthenaClient(integration_settings)

    # NOTE(review): moto 5 removed the per-service decorators in favour of
    # ``mock_aws``; confirm the pinned moto version still provides these.
    @mock_athena
    @mock_glue
    @mock_s3
    def test_athena_connection_test(self, integration_settings, aws_credentials):
        """Test Athena connection establishment."""
        client = AthenaClient(integration_settings)

        # Mock AWS services setup
        glue_client = boto3.client("glue", region_name="us-east-1")

        # Create mock database
        glue_client.create_database(
            DatabaseInput={
                "Name": "test_security_lake_db",
                "Description": "Test Security Lake database",
            }
        )

        # Test connection - should work with mocked services
        # Note: This test validates the connection logic, actual AWS calls are mocked
        with patch.object(client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.list_databases.return_value = {
                "DatabaseList": [{"Name": "test_security_lake_db"}]
            }

            # This should complete without errors
            # NOTE(review): `_get_client` is patched here, so this assertion
            # only inspects the mock's return value; consider exercising a
            # public method of AthenaClient instead.
            result = client._get_client()
            assert result is not None

    @pytest.mark.asyncio
    async def test_query_execution_flow(self, mock_athena_client):
        """Test complete query execution flow.

        Scripts a RUNNING -> SUCCEEDED status progression and a result set
        with one header row and one data row, then checks the data row comes
        back as a dict keyed by the header values.
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value

            # Mock query execution flow
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }

            # Mock query status progression
            status_responses = [
                {"QueryExecution": {"Status": {"State": "RUNNING"}}},
                {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}},
            ]
            athena_api.get_query_execution.side_effect = status_responses

            # Mock query results
            athena_api.get_query_results.return_value = {
                "ResultSet": {
                    "Rows": [
                        {
                            "Data": [
                                {"VarCharValue": "time"},
                                {"VarCharValue": "src_ip"},
                                {"VarCharValue": "dst_ip"},
                            ]
                        },
                        {
                            "Data": [
                                {"VarCharValue": "2024-01-15T10:30:00Z"},
                                {"VarCharValue": "192.168.1.100"},
                                {"VarCharValue": "203.0.113.45"},
                            ]
                        },
                    ]
                }
            }

            # Execute query
            results = await mock_athena_client.execute_query(
                "SELECT time, src_ip, dst_ip FROM test_table LIMIT 10"
            )

            assert len(results) == 1
            assert results[0]["time"] == "2024-01-15T10:30:00Z"
            assert results[0]["src_ip"] == "192.168.1.100"
            assert results[0]["dst_ip"] == "203.0.113.45"

    @pytest.mark.asyncio
    async def test_query_timeout_handling(self, mock_athena_client):
        """Test query timeout handling.

        The query status is always RUNNING, so ``execute_query`` is expected
        to raise ``RuntimeError`` matching "Query timeout".
        """
        # Set a very short timeout for testing
        mock_athena_client.settings.query_timeout_seconds = 1

        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }

            # Mock query that never completes
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "RUNNING"}}
            }

            # Should raise timeout error
            with pytest.raises(RuntimeError, match="Query timeout"):
                await mock_athena_client.execute_query("SELECT * FROM test_table")

    @pytest.mark.asyncio
    async def test_query_failure_handling(self, mock_athena_client):
        """Test handling of failed queries.

        A FAILED status is expected to surface as ``RuntimeError`` matching
        "Query failed".
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }

            # Mock failed query
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {
                    "Status": {
                        "State": "FAILED",
                        "StateChangeReason": "Table not found",
                    }
                }
            }

            # Should raise runtime error with failure reason
            with pytest.raises(RuntimeError, match="Query failed"):
                await mock_athena_client.execute_query("SELECT * FROM nonexistent_table")

    @pytest.mark.asyncio
    async def test_large_result_pagination(self, mock_athena_client):
        """Test handling of paginated query results.

        Two pages are served in order; only the first carries the header row
        and a ``NextToken``. Both data rows are expected in the output.
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
            }

            # Mock paginated results
            first_page = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "col1"}, {"VarCharValue": "col2"}]},
                        {"Data": [{"VarCharValue": "row1_col1"}, {"VarCharValue": "row1_col2"}]},
                    ]
                },
                "NextToken": "next-token-123",
            }
            second_page = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "row2_col1"}, {"VarCharValue": "row2_col2"}]}
                    ]
                }
                # No NextToken indicates end of results
            }
            athena_api.get_query_results.side_effect = [first_page, second_page]

            results = await mock_athena_client.execute_query("SELECT * FROM test_table")

            assert len(results) == 2
            assert results[0]["col1"] == "row1_col1"
            assert results[1]["col1"] == "row2_col1"

    @pytest.mark.asyncio
    async def test_result_limit_enforcement(self, mock_athena_client):
        """Test that result limits are properly enforced.

        With ``max_query_results`` set to 2 and four data rows available,
        exactly two rows are expected back.
        """
        # Set a low limit for testing
        mock_athena_client.settings.max_query_results = 2

        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.return_value = {
                "QueryExecutionId": "test-execution-id"
            }
            athena_api.get_query_execution.return_value = {
                "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
            }

            # Mock results with more rows than the limit
            athena_api.get_query_results.return_value = {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "col1"}]},  # Header
                        {"Data": [{"VarCharValue": "row1"}]},
                        {"Data": [{"VarCharValue": "row2"}]},
                        {"Data": [{"VarCharValue": "row3"}]},  # Should be truncated
                        {"Data": [{"VarCharValue": "row4"}]},  # Should be truncated
                    ]
                }
            }

            results = await mock_athena_client.execute_query("SELECT * FROM test_table")

            # Should only return up to the limit
            assert len(results) == 2

    @pytest.mark.asyncio
    async def test_list_databases_functionality(self, mock_athena_client):
        """Test database listing functionality.

        The three mocked database names are expected to be returned by
        ``list_databases``.
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.list_databases.return_value = {
                "DatabaseList": [
                    {"Name": "test_security_lake_db"},
                    {"Name": "default"},
                    {"Name": "other_database"},
                ]
            }

            databases = await mock_athena_client.list_databases()

            assert len(databases) == 3
            assert "test_security_lake_db" in databases
            assert "default" in databases
            assert "other_database" in databases

    @pytest.mark.asyncio
    async def test_list_tables_functionality(self, mock_athena_client):
        """Test table listing functionality.

        One table's metadata is mocked; the returned entry is checked for its
        name, type, column count and an ``s3://`` location.
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            athena_api = mock_get_client.return_value
            athena_api.list_table_metadata.return_value = {
                "TableMetadataList": [
                    {
                        "Name": "amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0",
                        "TableType": "EXTERNAL_TABLE",
                        "Columns": [
                            {"Name": "time", "Type": "timestamp"},
                            {"Name": "src_ip", "Type": "string"},
                        ],
                        "Parameters": {
                            "location": "s3://aws-security-data-lake-us-east-1-123456789012/"
                        },
                        "LastAccessTime": "2024-01-15T10:00:00Z",
                    }
                ]
            }

            tables = await mock_athena_client.list_tables("test_security_lake_db")

            assert len(tables) == 1
            table = tables[0]
            assert table["name"] == "amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0"
            assert table["type"] == "EXTERNAL_TABLE"
            assert table["columns"] == 2
            assert "s3://" in table["location"]

    @pytest.mark.asyncio
    async def test_credentials_error_handling(self, integration_settings):
        """Test handling of AWS credentials errors.

        ``_get_client`` raising ``NoCredentialsError`` is expected to surface
        as ``RuntimeError`` matching "AWS credentials not configured".
        """
        client = AthenaClient(integration_settings)

        # Mock credentials error
        with patch.object(client, '_get_client') as mock_get_client:
            from botocore.exceptions import NoCredentialsError

            mock_get_client.side_effect = NoCredentialsError()

            with pytest.raises(RuntimeError, match="AWS credentials not configured"):
                await client.execute_query("SELECT 1")

    @pytest.mark.asyncio
    async def test_aws_service_error_handling(self, mock_athena_client):
        """Test handling of AWS service errors.

        A ``ClientError`` from ``start_query_execution`` is expected to
        surface as ``RuntimeError`` whose message includes the AWS error code.
        """
        with patch.object(mock_athena_client, '_get_client') as mock_get_client:
            from botocore.exceptions import ClientError

            error_response = {
                "Error": {
                    "Code": "InvalidRequestException",
                    "Message": "Query syntax error",
                }
            }

            athena_api = mock_get_client.return_value
            athena_api.start_query_execution.side_effect = ClientError(
                error_response, "StartQueryExecution"
            )

            with pytest.raises(RuntimeError, match="AWS error \\(InvalidRequestException\\)"):
                await mock_athena_client.execute_query("INVALID SQL")

    @pytest.mark.asyncio
    async def test_connection_test_functionality(self, mock_athena_client):
        """Test connection test functionality.

        ``list_databases`` is patched on the client: a normal return should
        make ``test_connection`` report ``True``, and a raised exception
        should make it report ``False``.
        """
        with patch.object(mock_athena_client, 'list_databases') as mock_list_db:
            # Successful connection test
            mock_list_db.return_value = ["test_db"]
            result = await mock_athena_client.test_connection()
            assert result is True

            # Failed connection test
            mock_list_db.side_effect = Exception("Connection failed")
            result = await mock_athena_client.test_connection()
            assert result is False

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/kebabmane/asl-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.