import os
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
from moto import mock_athena, mock_glue, mock_s3
from src.asl_mcp_server.config.settings import Settings
@pytest.fixture
def mock_settings():
    """Build a ``Settings`` instance tuned for fast, deterministic tests.

    Points at a throwaway Security Lake database, Athena workgroup and S3
    results location in us-east-1, enables DEBUG logging, caps results at
    100 rows with a 60 second timeout, and turns query caching off.
    """
    overrides = dict(
        aws_region="us-east-1",
        security_lake_database="test_security_lake_db",
        athena_workgroup="test-workgroup",
        athena_output_location="s3://test-athena-results/",
        log_level="DEBUG",
        max_query_results=100,
        query_timeout_seconds=60,
        enable_query_caching=False,
    )
    return Settings(**overrides)
@pytest.fixture
def sample_ip_search_results():
    """Return two flattened search hits that both originate from 192.168.1.100.

    Every row carries the same vendor, account and region context; ports are
    kept as strings, exactly as the rows would arrive from a query result.
    """
    def make_row(when, event_type, level, activity, target_ip, sport, dport, product):
        # Key order mirrors the column order of the flattened result rows.
        return {
            "time": when,
            "type_name": event_type,
            "severity": level,
            "activity_name": activity,
            "src_ip": "192.168.1.100",
            "dst_ip": target_ip,
            "src_port": sport,
            "dst_port": dport,
            "product_name": product,
            "vendor_name": "AWS",
            "account_id": "123456789012",
            "region": "us-east-1",
        }

    return [
        make_row("2024-01-15T10:30:00Z", "Network Activity", "Medium", "Traffic",
                 "203.0.113.45", "3456", "443", "VPC Flow Logs"),
        make_row("2024-01-15T10:25:00Z", "API Activity", "Low", "Connect",
                 "10.0.1.50", "54321", "22", "CloudTrail"),
    ]
@pytest.fixture
def sample_guardduty_results():
    """Return a list holding one flattened GuardDuty finding row.

    The nested columns (``finding_types``, ``resources``, ``remediation``) are
    JSON-encoded strings, and ``severity_id`` is the string ``"4"``.
    """
    finding = dict(
        time="2024-01-15T11:00:00Z",
        finding_id="12345abc-def0-1234-5678-90abcdef1234",
        finding_title="Suspicious network activity",
        finding_description="Instance i-1234567890abcdef0 is communicating with a known malicious IP",
        severity="High",
        severity_id="4",
        type_name="Trojan:EC2/MaliciousIP",
        activity_name="Detection",
        finding_types='["Trojan", "Network"]',
        resources='[{"uid": "i-1234567890abcdef0", "type": "EC2Instance"}]',
        remediation='{"recommendation": "Block traffic to malicious IP"}',
        src_ip="10.0.1.100",
        dst_ip="198.51.100.42",
        account_id="123456789012",
        region="us-east-1",
        product_name="GuardDuty",
        product_version="1.0",
    )
    return [finding]
@pytest.fixture
def sample_data_sources():
    """Return one Glue-style table description for a CloudTrail management table.

    The table exposes three columns (``time``, ``type_name`` and a struct-typed
    ``src_endpoint``) plus its S3 location and last-access timestamp.
    """
    column_specs = (
        ("time", "timestamp"),
        ("type_name", "string"),
        ("src_endpoint", "struct<ip:string,port:int>"),
    )
    table = {
        "Name": "amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0",
        "TableType": "EXTERNAL_TABLE",
        "Columns": [{"Name": col, "Type": col_type} for col, col_type in column_specs],
        "Location": "s3://aws-security-data-lake-us-east-1-123456789012/AWSLogs/123456789012/",
        "LastAccessTime": "2024-01-15T10:00:00Z",
    }
    return [table]
@pytest.fixture
def mock_athena_client():
    """Provide a MagicMock stand-in for the Athena client wrapper.

    Canned responses:
      * ``execute_query()``  -> ``[]`` (query succeeds, no rows)
      * ``list_databases()`` -> the test database and ``"default"``
      * ``list_tables()``    -> ``[]``
      * ``test_connection()`` -> ``True``
    Any other attribute is an auto-created MagicMock.
    """
    canned_responses = {
        "execute_query.return_value": [],
        "list_databases.return_value": ["test_security_lake_db", "default"],
        "list_tables.return_value": [],
        "test_connection.return_value": True,
    }
    fake_client = MagicMock()
    fake_client.configure_mock(**canned_responses)
    return fake_client
@pytest.fixture
def mock_query_builder():
    """Provide a MagicMock stand-in for the SQL query builder.

    Both ``build_*`` methods return a ``(sql, params)`` tuple using ``?``
    placeholders, IP validation always succeeds, and
    ``sanitize_query_parameter`` echoes its argument back unchanged.
    """
    ip_search = (
        "SELECT * FROM test_table WHERE src_ip = ? OR dst_ip = ?",
        ["192.168.1.1", "192.168.1.1"],
    )
    guardduty_search = (
        "SELECT * FROM test_table WHERE finding_id = ?",
        ["test-finding-id"],
    )
    fake_builder = MagicMock()
    fake_builder.configure_mock(**{
        "build_ip_search_query.return_value": ip_search,
        "build_guardduty_search_query.return_value": guardduty_search,
        "validate_ip_address.return_value": True,
        "sanitize_query_parameter.side_effect": lambda value: value,
    })
    return fake_builder
@pytest.fixture
def sample_ocsf_event():
    """Return a nested OCSF-style network event used as the well-formed case.

    It carries classification ids and names, severity, product metadata, the
    cloud account/region, and source/destination endpoints with integer ports.
    """
    label = "Network Activity"
    product = {"name": "VPC Flow Logs", "vendor_name": "AWS"}
    source = {"ip": "192.168.1.100", "port": 3456}
    destination = {"ip": "203.0.113.45", "port": 443}
    return {
        "time": "2024-01-15T10:30:00Z",
        "type_name": label,
        # NOTE(review): OCSF defines type_uid as class_uid * 100 + activity_id
        # (400106 for this event); 4001 is kept unchanged here — confirm
        # whether any validator or test relies on the current value.
        "type_uid": 4001,
        "class_name": label,
        "class_uid": 4001,
        "category_name": label,
        "category_uid": 4,
        "activity_name": "Traffic",
        "activity_id": 6,
        "severity": "Medium",
        "severity_id": 3,
        "metadata": {"version": "1.1.0", "product": product},
        "cloud": {"account": {"uid": "123456789012"}, "region": "us-east-1"},
        "src_endpoint": source,
        "dst_endpoint": destination,
    }
@pytest.fixture
def invalid_ocsf_event():
    """Return an event intended to fail OCSF validation in several ways.

    Deliberately absent: time, type_uid, class_name, class_uid, metadata.
    Deliberately malformed: severity, src_endpoint.ip, src_endpoint.port.
    """
    broken_endpoint = dict(
        ip="invalid-ip",  # not an IP address
        port=99999,  # beyond the 0-65535 port range
    )
    return dict(
        type_name="Network Activity",
        severity="Invalid",  # not a valid severity label
        src_endpoint=broken_endpoint,
    )
@pytest.fixture(scope="function")
def aws_credentials():
    """Point boto3 at fake AWS credentials so moto-backed tests never reach AWS.

    The previous value of every variable touched here is captured before it
    is overwritten and restored on teardown. Without this, real credentials in
    the developer's shell would be clobbered for the rest of the session and
    leak into unrelated tests. ``AWS_DEFAULT_REGION`` is set as well, so boto3
    clients created without an explicit region do not fail.
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    # None marks "was not set", so teardown removes it instead of restoring.
    saved = {key: os.environ.get(key) for key in fake_env}
    os.environ.update(fake_env)
    try:
        yield
    finally:
        for key, previous in saved.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
@pytest.fixture
def mock_athena_service(aws_credentials):
    """Activate moto's Athena mock for the duration of one test.

    Requests ``aws_credentials`` so the fake credentials are in place first.
    """
    athena_mock = mock_athena()
    athena_mock.start()
    try:
        yield
    finally:
        athena_mock.stop()
@pytest.fixture
def mock_glue_service(aws_credentials):
    """Activate moto's Glue mock for the duration of one test.

    Requests ``aws_credentials`` so the fake credentials are in place first.
    """
    glue_mock = mock_glue()
    glue_mock.start()
    try:
        yield
    finally:
        glue_mock.stop()
@pytest.fixture
def mock_s3_service(aws_credentials):
    """Activate moto's S3 mock for the duration of one test.

    Requests ``aws_credentials`` so the fake credentials are in place first.
    """
    s3_mock = mock_s3()
    s3_mock.start()
    try:
        yield
    finally:
        s3_mock.stop()
@pytest.fixture
def mock_aws_services(mock_athena_service, mock_glue_service, mock_s3_service):
    """Activate the Athena, Glue and S3 moto mocks together.

    All the work happens in the requested fixtures; this one only groups them.
    """
    # Nothing to set up or tear down here; dependency fixtures handle both.
    yield None
@pytest.fixture
def sample_athena_query_results():
    """Return raw Athena ``Rows`` entries: a header row followed by one data row.

    Each cell is a ``{"VarCharValue": ...}`` mapping inside the row's ``Data`` list.
    """
    def as_row(*cells):
        return {"Data": [{"VarCharValue": cell} for cell in cells]}

    header = as_row("time", "src_ip", "dst_ip", "severity")
    record = as_row("2024-01-15T10:30:00Z", "192.168.1.100", "203.0.113.45", "Medium")
    return [header, record]
@pytest.fixture
def environment_variables():
    """Set the ``ASL_MCP_*`` environment variables for one test, then restore them.

    Yields the mapping of variables that were applied. On teardown, each
    variable goes back to the value it had before the fixture ran, or is
    removed if it was previously unset. The old teardown deleted every key
    unconditionally, which destroyed any pre-existing value in the developer's
    or CI environment.
    """
    env_vars = {
        "ASL_MCP_AWS_REGION": "us-east-1",
        "ASL_MCP_SECURITY_LAKE_DATABASE": "test_db",
        "ASL_MCP_ATHENA_OUTPUT_LOCATION": "s3://test-bucket/results/",
        "ASL_MCP_LOG_LEVEL": "DEBUG",
    }
    # Snapshot before overwriting; None means "absent before the test".
    saved = {key: os.environ.get(key) for key in env_vars}
    os.environ.update(env_vars)
    try:
        yield env_vars
    finally:
        # Iterate the snapshot, not env_vars, in case the test mutated the yielded dict.
        for key, previous in saved.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous