"""Unit tests for FabricSQLService."""
import struct
from unittest.mock import Mock, patch
import pytest
from ms_fabric_mcp_server.client.exceptions import FabricConnectionError, FabricError
from ms_fabric_mcp_server.models.item import FabricItem
from ms_fabric_mcp_server.models.results import QueryResult
@pytest.fixture
def sql_module():
    """Return the ``ms_fabric_mcp_server.services.sql`` module object.

    The import is done inside the fixture instead of at the top of the file.
    Tests monkeypatch module-level names (``PYODBC_AVAILABLE``, ``pyodbc``,
    ``DefaultAzureCredential`` ...) on the object returned here.
    """
    from ms_fabric_mcp_server.services import sql as module_under_test

    return module_under_test
@pytest.fixture
def sql_service(mock_fabric_client, sql_module, monkeypatch):
    """Build a FabricSQLService whose external dependencies are all mocks.

    The module's ``pyodbc`` and ``DefaultAzureCredential`` names are replaced
    with ``Mock`` objects. ``PYODBC_AVAILABLE`` is forced to True and
    ``OTEL_AVAILABLE`` to False before the service is constructed.

    Returns:
        A 5-tuple ``(service, workspace_service, item_service, pyodbc_mock,
        azure_cred_mock)``. Tests unpack it positionally.
    """
    fake_pyodbc = Mock()
    fake_credential_cls = Mock()

    # Applied in insertion order; same patch sequence as one-by-one setattr calls.
    module_patches = {
        "PYODBC_AVAILABLE": True,
        "OTEL_AVAILABLE": False,
        "pyodbc": fake_pyodbc,
        "DefaultAzureCredential": fake_credential_cls,
    }
    for attr_name, replacement in module_patches.items():
        monkeypatch.setattr(sql_module, attr_name, replacement)

    fake_workspaces = Mock()
    fake_items = Mock()
    service = sql_module.FabricSQLService(
        mock_fabric_client,
        fake_workspaces,
        fake_items,
    )
    return service, fake_workspaces, fake_items, fake_pyodbc, fake_credential_cls
@pytest.mark.unit
class TestFabricSQLService:
    """Test suite for FabricSQLService.

    Tests that use the ``sql_service`` fixture get a service whose ``pyodbc``
    and ``DefaultAzureCredential`` are ``Mock`` objects.
    """

    # ODBC pre-connect attribute id used to hand the driver an access token
    # (SQL_COPT_SS_ACCESS_TOKEN in the Microsoft ODBC Driver for SQL Server).
    SQL_COPT_SS_ACCESS_TOKEN = 1256

    @staticmethod
    def _attach_cursor(service, cursor):
        """Install a mock shared connection on *service* whose cursor() is *cursor*.

        Returns the mock connection so tests can assert on commit/rollback.
        """
        connection = Mock()
        connection.cursor.return_value = cursor
        service._connection = connection
        return connection

    @staticmethod
    def _stub_endpoint_lookup(
        workspace_service, item_service, fabric_client, item_id, item_type, payload
    ):
        """Wire the mocks that get_sql_endpoint consults.

        The workspace resolves to ``"ws-1"``. The item lookup returns a
        FabricItem of *item_type* whose display name equals its type. The
        Fabric API response's ``json()`` returns *payload*.
        """
        workspace_service.resolve_workspace_id.return_value = "ws-1"
        item_service.get_item_by_name.return_value = FabricItem(
            id=item_id,
            display_name=item_type,
            type=item_type,
            workspace_id="ws-1",
        )
        response = Mock()
        response.json.return_value = payload
        fabric_client.make_api_request.return_value = response

    def test_init_missing_pyodbc_raises(self, sql_module, mock_fabric_client, monkeypatch):
        """Missing pyodbc raises ImportError."""
        monkeypatch.setattr(sql_module, "PYODBC_AVAILABLE", False)
        with pytest.raises(ImportError):
            sql_module.FabricSQLService(mock_fabric_client, Mock(), Mock())

    def test_get_sql_endpoint_warehouse(self, sql_service, mock_fabric_client):
        """Warehouse SQL endpoint uses properties.connectionString."""
        service, workspace_service, item_service, _, _ = sql_service
        self._stub_endpoint_lookup(
            workspace_service,
            item_service,
            mock_fabric_client,
            "wh-1",
            "Warehouse",
            {"properties": {"connectionString": "wh-endpoint"}},
        )

        endpoint = service.get_sql_endpoint("Workspace", "Warehouse", "Warehouse")

        assert endpoint == "wh-endpoint"
        mock_fabric_client.make_api_request.assert_called_once_with(
            "GET", "workspaces/ws-1/warehouses/wh-1"
        )

    def test_get_sql_endpoint_lakehouse(self, sql_service, mock_fabric_client):
        """Lakehouse SQL endpoint uses sqlEndpointProperties.connectionString."""
        service, workspace_service, item_service, _, _ = sql_service
        self._stub_endpoint_lookup(
            workspace_service,
            item_service,
            mock_fabric_client,
            "lh-1",
            "Lakehouse",
            {"properties": {"sqlEndpointProperties": {"connectionString": "lh-endpoint"}}},
        )

        endpoint = service.get_sql_endpoint("Workspace", "Lakehouse", "Lakehouse")

        assert endpoint == "lh-endpoint"
        mock_fabric_client.make_api_request.assert_called_once_with(
            "GET", "workspaces/ws-1/lakehouses/lh-1"
        )

    def test_get_sql_endpoint_missing_connection_string(self, sql_service, mock_fabric_client):
        """Missing connection string raises FabricError."""
        service, workspace_service, item_service, _, _ = sql_service
        self._stub_endpoint_lookup(
            workspace_service,
            item_service,
            mock_fabric_client,
            "wh-1",
            "Warehouse",
            {"properties": {}},
        )

        with pytest.raises(FabricError):
            service.get_sql_endpoint("Workspace", "Warehouse", "Warehouse")

    def test_get_sql_endpoint_invalid_item_type(self, sql_service):
        """Invalid item type raises ValueError."""
        service, *_ = sql_service
        with pytest.raises(ValueError):
            service.get_sql_endpoint("Workspace", "Item", "Invalid")

    def test_get_token_bytes_success(self, sql_service):
        """Token bytes are a little-endian length prefix followed by the token."""
        service, _, _, _, azure_cred_mock = sql_service
        token = Mock()
        token.token = "abc"
        azure_cred_instance = Mock()
        azure_cred_instance.get_token.return_value = token
        azure_cred_mock.return_value = azure_cred_instance

        token_bytes = service._get_token_bytes()

        azure_cred_instance.get_token.assert_called_once_with(
            "https://database.windows.net/.default"
        )
        length = struct.unpack("<i", token_bytes[:4])[0]
        assert length == len(token_bytes[4:])
        # The access-token structure given to the Microsoft ODBC driver carries
        # the token as UTF-16-LE; pin the payload, not just its length.
        assert token_bytes[4:] == "abc".encode("utf-16-le")

    def test_get_token_bytes_failure(self, sql_service):
        """Token failures raise FabricConnectionError."""
        service, _, _, _, azure_cred_mock = sql_service
        azure_cred_mock.side_effect = RuntimeError("boom")
        with pytest.raises(FabricConnectionError):
            service._get_token_bytes()

    def test_connect_adds_port_and_calls_pyodbc(self, sql_service):
        """Connect adds default port and uses token bytes."""
        service, _, _, pyodbc_mock, _ = sql_service
        service._get_token_bytes = Mock(return_value=b"token")
        pyodbc_mock.connect.return_value = Mock()

        service.connect("server-host")

        args, kwargs = pyodbc_mock.connect.call_args
        assert "Server=server-host,1433" in args[0]
        assert kwargs["attrs_before"][self.SQL_COPT_SS_ACCESS_TOKEN] == b"token"
        assert kwargs["autocommit"] is True
        # The stored connection must be the one pyodbc returned, not merely non-None.
        assert service._connection is pyodbc_mock.connect.return_value

    def test_connect_failure(self, sql_service):
        """Connection failures raise FabricConnectionError."""
        service, _, _, pyodbc_mock, _ = sql_service
        service._get_token_bytes = Mock(return_value=b"token")
        pyodbc_mock.connect.side_effect = RuntimeError("boom")
        with pytest.raises(FabricConnectionError):
            service.connect("server-host")

    def test_execute_query_requires_connection(self, sql_service):
        """execute_query raises when not connected."""
        service, *_ = sql_service
        service._connection = None
        with pytest.raises(FabricConnectionError):
            service.execute_query("SELECT 1")

    def test_execute_query_success(self, sql_service):
        """execute_query returns QueryResult with rows."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.description = [("id",), ("name",)]
        cursor.fetchall.return_value = [(1, "a"), (2, "b")]
        self._attach_cursor(service, cursor)

        result = service.execute_query("SELECT * FROM test")

        assert result.status == "success"
        assert result.row_count == 2
        assert result.columns == ["id", "name"]
        assert result.data == [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]

    def test_execute_query_failure_returns_error(self, sql_service):
        """execute_query returns error result on failure."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.execute.side_effect = RuntimeError("boom")
        self._attach_cursor(service, cursor)

        result = service.execute_query("SELECT * FROM test")

        assert result.status == "error"
        assert "Query execution failed" in result.message

    def test_execute_query_retries_transient_error(self, sql_service):
        """execute_query retries once on transient ODBC errors."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.description = [("id",)]
        cursor.fetchall.return_value = [(1,)]
        # First attempt fails with an HYT00 (timeout) message, second succeeds.
        cursor.execute.side_effect = [RuntimeError("HYT00: timeout"), None]
        self._attach_cursor(service, cursor)
        # NOTE(review): presumably clears the state a reconnect would need, so
        # the retry reuses the mocked connection — confirm against execute_query.
        service._sql_endpoint = None
        service._database = None

        with patch("ms_fabric_mcp_server.services.sql.time.sleep") as sleep_mock, patch(
            "ms_fabric_mcp_server.services.sql.random.uniform", return_value=5.0
        ):
            result = service.execute_query("SELECT 1")

        assert result.status == "success"
        assert cursor.execute.call_count == 2
        sleep_mock.assert_called_once()

    def test_execute_statement_requires_connection(self, sql_service):
        """execute_statement raises when not connected."""
        service, *_ = sql_service
        service._connection = None
        with pytest.raises(FabricConnectionError):
            service.execute_statement("UPDATE test SET a=1")

    def test_execute_statement_success(self, sql_service):
        """execute_statement executes and returns success (autocommit, no manual commit)."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.rowcount = 3
        connection = self._attach_cursor(service, cursor)

        result = service.execute_statement("UPDATE test SET a=1")

        assert result["status"] == "success"
        assert result["affected_rows"] == 3
        connection.commit.assert_not_called()

    def test_execute_statement_rejects_non_dml(self, sql_service):
        """execute_statement returns error for non-DML statements without running them."""
        service, *_ = sql_service
        cursor = Mock()
        self._attach_cursor(service, cursor)

        result = service.execute_statement("SELECT 1")

        assert result["status"] == "error"
        assert result["affected_rows"] == 0
        # Rejection must happen before anything reaches the database.
        cursor.execute.assert_not_called()

    def test_execute_statement_rejects_ddl_without_flag(self, sql_service):
        """execute_statement rejects DDL when allow_ddl is False, without running it."""
        service, *_ = sql_service
        cursor = Mock()
        self._attach_cursor(service, cursor)

        result = service.execute_statement("CREATE TABLE test (id int)")

        assert result["status"] == "error"
        assert result["affected_rows"] == 0
        cursor.execute.assert_not_called()

    def test_execute_statement_allows_ddl_with_flag(self, sql_service):
        """execute_statement allows DDL when allow_ddl is True."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.rowcount = 0
        connection = self._attach_cursor(service, cursor)

        result = service.execute_statement("CREATE TABLE test (id int)", allow_ddl=True)

        assert result["status"] == "success"
        connection.commit.assert_not_called()

    def test_execute_statement_failure(self, sql_service):
        """execute_statement returns error on failure (no rollback with autocommit)."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.execute.side_effect = RuntimeError("boom")
        connection = self._attach_cursor(service, cursor)

        result = service.execute_statement("UPDATE test SET a=1")

        assert result["status"] == "error"
        connection.rollback.assert_not_called()

    def test_execute_sql_query_calls_close_on_error(self, sql_service):
        """execute_sql_query closes its own connection even on error."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.execute.side_effect = RuntimeError("boom")
        mock_conn = Mock()
        mock_conn.cursor.return_value = cursor
        service._create_connection = Mock(return_value=mock_conn)

        result = service.execute_sql_query("endpoint", "SELECT 1")

        assert result.status == "error"
        mock_conn.close.assert_called_once()

    def test_execute_sql_query_closes_on_success(self, sql_service):
        """execute_sql_query creates independent connection and closes it."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.description = [("id",)]
        cursor.fetchall.return_value = [(1,)]
        mock_conn = Mock()
        mock_conn.cursor.return_value = cursor
        service._create_connection = Mock(return_value=mock_conn)

        result = service.execute_sql_query("endpoint", "SELECT 1")

        assert result.status == "success"
        service._create_connection.assert_called_once_with("endpoint", "Metadata")
        mock_conn.close.assert_called_once()

    def test_execute_sql_statement_calls_close_on_error(self, sql_service):
        """execute_sql_statement closes its own connection even on error."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.execute.side_effect = RuntimeError("boom")
        mock_conn = Mock()
        mock_conn.cursor.return_value = cursor
        service._create_connection = Mock(return_value=mock_conn)

        result = service.execute_sql_statement("endpoint", "UPDATE test SET a=1")

        assert result["status"] == "error"
        mock_conn.close.assert_called_once()

    def test_execute_sql_statement_closes_on_success(self, sql_service):
        """execute_sql_statement creates independent connection and closes it."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.rowcount = 1
        mock_conn = Mock()
        mock_conn.cursor.return_value = cursor
        service._create_connection = Mock(return_value=mock_conn)

        result = service.execute_sql_statement("endpoint", "UPDATE test SET a=1")

        assert result["status"] == "success"
        service._create_connection.assert_called_once_with("endpoint", "Metadata")
        mock_conn.close.assert_called_once()

    def test_get_tables_success(self, sql_service):
        """get_tables returns table list on success."""
        service, *_ = sql_service
        service.execute_query = Mock(
            return_value=QueryResult(
                status="success",
                data=[{"TABLE_NAME": "a"}, {"TABLE_NAME": "b"}],
                columns=["TABLE_NAME"],
                row_count=2,
            )
        )

        tables = service.get_tables("dbo")

        assert tables == ["a", "b"]

    def test_get_tables_error(self, sql_service):
        """get_tables returns empty list on error."""
        service, *_ = sql_service
        service.execute_query = Mock(return_value=QueryResult(status="error"))

        tables = service.get_tables("dbo")

        assert tables == []

    def test_get_table_schema_success(self, sql_service):
        """get_table_schema returns schema details on success."""
        service, *_ = sql_service
        schema_data = [{"COLUMN_NAME": "id", "DATA_TYPE": "int"}]
        service.execute_query = Mock(
            return_value=QueryResult(status="success", data=schema_data)
        )

        result = service.get_table_schema("table", "dbo")

        assert result == schema_data

    def test_get_table_schema_error(self, sql_service):
        """get_table_schema returns empty list on error."""
        service, *_ = sql_service
        service.execute_query = Mock(return_value=QueryResult(status="error"))

        result = service.get_table_schema("table", "dbo")

        assert result == []

    def test_is_connected_false_when_none(self, sql_service):
        """is_connected returns False when no connection."""
        service, *_ = sql_service
        service._connection = None
        assert service.is_connected() is False

    def test_is_connected_true(self, sql_service):
        """is_connected returns True on successful ping."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.fetchone.return_value = (1,)
        self._attach_cursor(service, cursor)
        assert service.is_connected() is True

    def test_is_connected_false_on_error(self, sql_service):
        """is_connected returns False on errors."""
        service, *_ = sql_service
        cursor = Mock()
        cursor.execute.side_effect = RuntimeError("boom")
        self._attach_cursor(service, cursor)
        assert service.is_connected() is False

    def test_close_resets_connection(self, sql_service):
        """close attempts to close and resets the connection even on close error."""
        service, *_ = sql_service
        connection = Mock()
        connection.close.side_effect = RuntimeError("boom")
        service._connection = connection

        service.close()

        connection.close.assert_called_once()
        assert service._connection is None

    def test_execute_sql_query_uses_independent_connections(self, sql_service):
        """Concurrent execute_sql_query calls each get their own connection."""
        service, *_ = sql_service
        cursor1 = Mock()
        cursor1.description = [("id",)]
        cursor1.fetchall.return_value = [(1,)]
        cursor2 = Mock()
        cursor2.description = [("id",)]
        cursor2.fetchall.return_value = [(2,)]
        conn1 = Mock()
        conn1.cursor.return_value = cursor1
        conn2 = Mock()
        conn2.cursor.return_value = cursor2
        # _create_connection returns a different connection on each call.
        service._create_connection = Mock(side_effect=[conn1, conn2])

        r1 = service.execute_sql_query("ep", "SELECT 1")
        r2 = service.execute_sql_query("ep", "SELECT 2")

        assert r1.status == "success"
        assert r2.status == "success"
        assert service._create_connection.call_count == 2
        conn1.close.assert_called_once()
        conn2.close.assert_called_once()
        # Shared state should NOT have been touched.
        assert service._connection is None

    def test_create_connection_returns_new_connection(self, sql_service):
        """_create_connection creates a fresh pyodbc.Connection."""
        service, _, _, pyodbc_mock, _ = sql_service
        service._get_token_bytes = Mock(return_value=b"token")
        pyodbc_mock.connect.return_value = Mock()

        conn = service._create_connection("server-host", "mydb")

        assert conn is pyodbc_mock.connect.return_value
        args, kwargs = pyodbc_mock.connect.call_args
        assert "Server=server-host,1433" in args[0]
        assert "Database=mydb" in args[0]
        # Should NOT mutate self._connection.
        assert service._connection is None