"""Unit tests for ExecuteSelectService and query validation (executor)."""
import pytest
from unittest.mock import Mock, patch
from mcp_sql.credentials import Credentials
from mcp_sql.executor import QueryExecutor
from mcp_sql.tools.execute_select import ExecuteSelectService
@pytest.mark.asyncio
class TestExecuteSelectService:
    """Tests for ExecuteSelectService: tool metadata and the execute() paths.

    The fixtures used here (``tool_dependencies``, ``tool_dependencies_invalid``,
    ``mock_context``, ``mock_credentials``) are not defined in this module;
    presumably they come from a conftest. The row/column counts asserted below
    are assumed to match the canned result of the fixtures' mocked executor.
    """

    @staticmethod
    def _assert_error_contains(outcome, fragment):
        """Assert *outcome* is a dict carrying an ``error`` message that includes *fragment*."""
        assert isinstance(outcome, dict)
        assert "error" in outcome
        assert fragment in outcome["error"]

    async def test_name_property(self, tool_dependencies):
        """The tool is exposed under the name ``execute_select_query``."""
        tool = ExecuteSelectService(**tool_dependencies)
        assert tool.name == "execute_select_query"

    async def test_description_property(self, tool_dependencies):
        """The description is non-empty and mentions SELECT and queries."""
        text = ExecuteSelectService(**tool_dependencies).description
        assert len(text) > 0
        assert "SELECT" in text
        assert "query" in text.lower()

    async def test_execute_returns_query_results(self, tool_dependencies, mock_context):
        """A successful SELECT yields a dict with columns, rows, row count and timing."""
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query="SELECT * FROM test_table",
            database="test_db",
            server_name="test_server",
            limit=100,
        )
        assert isinstance(outcome, dict)
        for required_key in ("columns", "rows", "row_count", "execution_time"):
            assert required_key in outcome
        assert len(outcome["columns"]) == 3
        assert len(outcome["rows"]) == 2

    async def test_execute_with_custom_limit(self, tool_dependencies, mock_context):
        """A caller-supplied limit is forwarded to the executor."""
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query="SELECT id, name FROM test_table",
            database="test_db",
            server_name="test_server",
            limit=50,
        )
        assert isinstance(outcome, dict)
        select_mock = tool_dependencies["executor"].execute_select_query
        select_mock.assert_called_once()
        # The service is expected to pass the limit as the third positional argument.
        assert select_mock.call_args.args[2] == 50

    async def test_execute_without_database(
        self, tool_dependencies, mock_context, mock_credentials
    ):
        """With no database on the credentials or in the call, execute() reports an error."""
        # Clear the database on the credentials fixture before the service is built.
        mock_credentials.database = None
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query="SELECT * FROM test_table",
            server_name="test_server",
        )
        self._assert_error_contains(outcome, "Database name is required")

    async def test_execute_with_empty_query(self, tool_dependencies, mock_context):
        """An empty query string is rejected."""
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query="",
            database="test_db",
            server_name="test_server",
        )
        self._assert_error_contains(outcome, "Query cannot be empty")

    async def test_execute_with_whitespace_only_query(self, tool_dependencies, mock_context):
        """A query made only of whitespace is treated as empty."""
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query=" \n\t ",
            database="test_db",
            server_name="test_server",
        )
        self._assert_error_contains(outcome, "Query cannot be empty")

    async def test_execute_with_invalid_credentials(self, tool_dependencies_invalid, mock_context):
        """Dependencies that cannot supply credentials produce a "Missing credentials" error."""
        tool = ExecuteSelectService(**tool_dependencies_invalid)
        outcome = await tool.execute(
            mock_context,
            query="SELECT * FROM test_table",
            database="test_db",
        )
        self._assert_error_contains(outcome, "Missing credentials")

    async def test_execute_calls_executor(self, tool_dependencies, mock_context):
        """execute() delegates to the executor exactly once."""
        tool = ExecuteSelectService(**tool_dependencies)
        await tool.execute(
            mock_context,
            query="SELECT * FROM test_table",
            database="test_db",
            server_name="test_server",
        )
        tool_dependencies["executor"].execute_select_query.assert_called_once()

    async def test_execute_with_all_optional_params(self, tool_dependencies, mock_context):
        """All optional overrides are accepted and credentials are resolved exactly once."""
        overrides = {
            "database": "custom_db",
            "server_name": "custom_server",
            "limit": 200,
            "user": "custom_user",
            "password": "custom_pass",
            "driver": "Custom Driver",
            "port": 3306,
        }
        tool = ExecuteSelectService(**tool_dependencies)
        outcome = await tool.execute(
            mock_context,
            query="SELECT * FROM test_table",
            **overrides,
        )
        assert isinstance(outcome, dict)
        tool_dependencies["credentials_manager"].get_from_context.assert_called_once()
class TestExecuteSelectQueryValidation:
    """Test query validation in QueryExecutor (WITH/CTE and excluded servers).

    Each test builds a QueryExecutor over a connection mock whose
    ``get_engine_with_credentials`` returns None. A query that passes
    validation therefore ends in a "Could not create connection" error. That
    lets the tests tell "rejected by validation" apart from "validated, then
    failed to connect" without a real database.
    """

    # Environment variable naming servers exempt from the forbidden-keyword check.
    _EXCLUSION_ENV_VAR = "MCP_SQL_SERVERS_EXCLUDED_FROM_KEYWORD_CHECK"
    _TRUSTED_SERVER = "trusted_host"
    # "create" appears only inside the identifier ``create_col``; the
    # non-excluded-server test below expects the keyword check to reject it anyway.
    _QUERY_MENTIONING_CREATE = "SELECT 1 AS create_col FROM (SELECT 1) t"
    # Error produced once validation passes and the mocked engine lookup returns None.
    _CONNECTION_ERROR = "Could not create connection"

    @staticmethod
    def _execute(server, query):
        """Run *query* against *server* through an executor whose connection always fails.

        Returns the dict produced by ``QueryExecutor.execute_select_query``.
        """
        connection = Mock()
        connection.get_engine_with_credentials = Mock(return_value=None)
        executor = QueryExecutor(connection)
        creds = Credentials(user="u", password="p", server=server, database="db")
        return executor.execute_select_query(creds, query, limit=10)

    def _exclusion_env(self):
        """Return a patch.dict context that lists only the trusted server as excluded."""
        return patch.dict(
            "os.environ",
            {self._EXCLUSION_ENV_VAR: self._TRUSTED_SERVER},
            clear=False,
        )

    def test_query_starting_with_with_cte_accepted(self):
        """Queries starting with WITH (CTE) are accepted for all servers."""
        result = self._execute(
            "any_server", "WITH cte AS (SELECT 1 AS x) SELECT * FROM cte"
        )
        # Validation must pass; the only failure should be the mocked connection.
        assert "error" in result
        assert "must start with SELECT" not in result["error"]
        assert self._CONNECTION_ERROR in result["error"]

    def test_excluded_server_allows_dangerous_keyword(self):
        """When server is in exclusion list, dangerous keywords are not rejected."""
        with self._exclusion_env():
            result = self._execute(self._TRUSTED_SERVER, self._QUERY_MENTIONING_CREATE)
        # The keyword check is skipped, so the run proceeds to the (failing) connection.
        assert "error" in result
        assert "forbidden keyword" not in result["error"].lower()
        assert self._CONNECTION_ERROR in result["error"]

    def test_non_excluded_server_rejects_dangerous_keyword(self):
        """When server is not in exclusion list, dangerous keyword returns error."""
        with self._exclusion_env():
            result = self._execute("other_host", self._QUERY_MENTIONING_CREATE)
        assert "error" in result
        # The error must come from validation, not from the mocked connection failure.
        assert self._CONNECTION_ERROR not in result["error"]
        assert "forbidden keyword" in result["error"].lower() or "CREATE" in result["error"]