import json
from unittest.mock import patch, MagicMock
import pytest
from hiveserver2_mcp.main import create_app
@pytest.fixture
def client():
    """Yield a Flask test client for an application built in testing mode.

    The client is entered as a context manager; the ``with`` block is exited
    during fixture teardown, after the requesting test has finished.
    """
    flask_app = create_app()
    flask_app.config["TESTING"] = True
    with flask_app.test_client() as http_client:
        yield http_client
def test_query_success(client):
    """POST /query returns the cursor's column names and rows as JSON."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = [("value1", 1), ("value2", 2)]
        # Only the first item (the column name) of each description entry is
        # populated; the endpoint is expected to read names from there.
        mock_cursor.description = [("column1",), ("column2",)]
        mock_connection = MagicMock()
        mock_connection.cursor.return_value = mock_cursor
        mock_get_connection.return_value = mock_connection

        response = client.post("/query", json={"query": "SELECT * FROM my_table"})

        # Check the status before decoding so an error page fails with the
        # real status code and body rather than a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["columns"] == ["column1", "column2"]
        # JSON has no tuple type, so each row tuple comes back as a list.
        assert data["rows"] == [["value1", 1], ["value2", 2]]
def test_query_missing_query(client):
    """POST /query without a "query" field is rejected with HTTP 400."""
    # Patch the connection factory so the test never reaches a real Hive
    # server, even if the handler opens a connection before validating input.
    with patch("hiveserver2_mcp.main.get_hive_connection"):
        response = client.post("/query", json={})

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 400, response.data
        data = json.loads(response.data)
        assert data["error"] == "Missing query"
def test_query_db_error(client):
    """A failure while obtaining the Hive connection is reported as HTTP 500."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_get_connection.side_effect = Exception("Database error")

        response = client.post("/query", json={"query": "SELECT * FROM my_table"})

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 500, response.data
        data = json.loads(response.data)
        # The handler is expected to echo the exception message to the caller.
        assert data["error"] == "Database error"
def test_list_tables_success(client):
    """GET /list_tables returns the first column of each fetched row."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_cursor = MagicMock()
        # One single-element row per table name.
        mock_cursor.fetchall.return_value = [("table1",), ("table2",)]
        mock_connection = MagicMock()
        mock_connection.cursor.return_value = mock_cursor
        mock_get_connection.return_value = mock_connection

        response = client.get("/list_tables")

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["tables"] == ["table1", "table2"]
def test_describe_table_success(client):
    """POST /describe_table maps (name, type) rows to a list of dicts."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_cursor = MagicMock()
        # Each row is (column_name, column_type).
        mock_cursor.fetchall.return_value = [("col1", "string"), ("col2", "int")]
        mock_connection = MagicMock()
        mock_connection.cursor.return_value = mock_cursor
        mock_get_connection.return_value = mock_connection

        response = client.post("/describe_table", json={"table": "my_table"})

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["schema"] == [
            {"name": "col1", "type": "string"},
            {"name": "col2", "type": "int"},
        ]
def test_list_databases_success(client):
    """GET /list_databases returns the first column of each fetched row."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_cursor = MagicMock()
        # One single-element row per database name.
        mock_cursor.fetchall.return_value = [("db1",), ("db2",)]
        mock_connection = MagicMock()
        mock_connection.cursor.return_value = mock_cursor
        mock_get_connection.return_value = mock_connection

        response = client.get("/list_databases")

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["databases"] == ["db1", "db2"]
def test_use_database_success(client):
    """POST /use_database acknowledges the switch with a confirmation message."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_connection = MagicMock()
        mock_get_connection.return_value = mock_connection

        response = client.post("/use_database", json={"database": "my_db"})

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["message"] == "Switched to database: my_db"
        # NOTE(review): this only checks the response text, not that a
        # `USE my_db` statement was executed on the mocked cursor. Consider
        # asserting on mock_connection.cursor().execute once the exact call
        # made by main.py is confirmed.
def test_status_success(client):
    """GET /status reports "ok" while the connection factory is patched to succeed."""
    with patch("hiveserver2_mcp.main.get_hive_connection") as mock_get_connection:
        mock_get_connection.return_value = MagicMock()

        response = client.get("/status")

        # Check the status before decoding so a non-JSON error page produces
        # a readable failure instead of a JSON decode error.
        assert response.status_code == 200, response.data
        data = json.loads(response.data)
        assert data["status"] == "ok"