test_chroma_ops.py•14.1 kB
"""Tests for low-level ChromaDB operations and utilities."""
import os
import pytest
from unittest.mock import patch, MagicMock, ANY, AsyncMock
import platform
import chromadb
from chromadb.api.client import ClientAPI
from chromadb.config import Settings
# Import the module containing the function under test
from chroma_mcp.utils import chroma_client as client_module
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData, INTERNAL_ERROR, INVALID_PARAMS
# Import ChromaClientConfig
from chroma_mcp.utils.chroma_client import (
ChromaClientConfig,
get_chroma_client,
get_embedding_function,
reset_client,
)
from src.chroma_mcp.utils.errors import (
ValidationError,
EmbeddingError,
ClientError,
ConfigurationError,
)
from src.chroma_mcp.utils.config import ServerConfig, load_config, get_collection_settings, validate_collection_name
from chroma_mcp.server import config_server # Needed to set up the client
@pytest.fixture(autouse=True)
def mock_logger():
"""Auto-mock the logger used within chroma_ops test module utilities."""
# Patch the logger in the module where it's used (e.g., chroma_client)
with patch("chroma_mcp.utils.chroma_client.get_logger") as mock_get_logger:
mock_log_instance = MagicMock()
mock_get_logger.return_value = mock_log_instance
yield mock_log_instance
# Client Tests
class TestChromaClient:
"""Test cases for ChromaDB client operations."""
def setup_method(self, method):
"""Reset client before each test."""
reset_client()
# Also reset embedding function state if necessary
client_module._embedding_function = None
def teardown_method(self, method):
"""Ensure client is reset after each test."""
reset_client()
client_module._embedding_function = None
def test_get_ephemeral_client(self):
"""Test getting an ephemeral client."""
with patch("chromadb.EphemeralClient") as mock_ephemeral_client:
mock_client = MagicMock(spec=ClientAPI)
mock_ephemeral_client.return_value = mock_client
config = ChromaClientConfig(client_type="ephemeral")
client = get_chroma_client(config)
assert isinstance(client, ClientAPI)
mock_ephemeral_client.assert_called_once_with(settings=ANY)
def test_get_persistent_client(self, tmp_path):
"""Test getting a persistent client."""
with patch("chromadb.PersistentClient") as mock_persistent_client:
mock_client = MagicMock(spec=ClientAPI)
mock_persistent_client.return_value = mock_client
data_dir = str(tmp_path / "chroma_data")
config = ChromaClientConfig(client_type="persistent", data_dir=data_dir)
client = get_chroma_client(config)
assert isinstance(client, ClientAPI)
mock_persistent_client.assert_called_once_with(path=data_dir, settings=ANY)
def test_get_http_client(self):
"""Test getting an HTTP client."""
with patch("chromadb.HttpClient") as mock_http_client:
mock_client = MagicMock(spec=ClientAPI)
mock_http_client.return_value = mock_client
config = ChromaClientConfig(client_type="http", host="localhost", port="8000", ssl=True)
client = get_chroma_client(config)
assert isinstance(client, ClientAPI)
mock_http_client.assert_called_once_with(
host="localhost", port="8000", ssl=True, tenant=None, database=None, settings=ANY
)
def test_get_persistent_client_without_data_dir(self):
"""Test error when getting persistent client without data_dir."""
config = ChromaClientConfig(client_type="persistent")
with pytest.raises(ValueError) as exc_info:
get_chroma_client(config)
assert "data_dir is required for persistent client" in str(exc_info.value)
def test_get_http_client_without_host(self):
"""Test error when getting HTTP client without host."""
config = ChromaClientConfig(client_type="http")
with pytest.raises(ValueError) as exc_info:
get_chroma_client(config)
assert "host is required for http client" in str(exc_info.value)
def test_reset_client(self):
"""Test resetting the client."""
with patch("chroma_mcp.utils.chroma_client._chroma_client") as mock_client:
reset_client()
mock_client.reset.assert_called_once()
def test_get_persistent_client_success(self, tmp_path):
"""Test successful execution of get_chroma_client for PersistentClient."""
# No patch needed, just call the function with valid config
data_dir = str(tmp_path / "chroma_success")
config_dict = {"client_type": "PersistentClient", "data_dir": data_dir}
config_obj = ChromaClientConfig(**config_dict)
client = get_chroma_client(config_obj)
# Assert a ClientAPI instance is returned
assert isinstance(client, ClientAPI)
def test_get_http_client_success(self, mocker):
"""Test successful execution of get_chroma_client for HttpClient."""
# Use mocker to prevent actual HTTP connection if HttpClient tries
mocker.patch("chromadb.HttpClient") # Simple patch to avoid real connection
config_dict = {
"client_type": "HttpClient",
"host": "127.0.0.1", # Use valid-looking host/port
"port": 8001,
"ssl": False,
"tenant": "test_tenant",
"database": "test_db",
"api_key": "dummy_key",
}
config_obj = ChromaClientConfig(**config_dict)
client = get_chroma_client(config_obj)
# Assert a ClientAPI instance is returned
assert isinstance(client, ClientAPI)
def test_get_persistent_client_init_fails(self, mocker):
"""Test handling when underlying PersistentClient initialization fails.
This now patches get_chroma_client itself.
"""
# Simulate the error that would occur inside get_chroma_client
expected_error = McpError(
ErrorData(code=INTERNAL_ERROR, message="Failed to initialize ChromaDB client: DB connection failed")
)
mock_get_client = mocker.patch("chroma_mcp.utils.chroma_client.get_chroma_client", side_effect=expected_error)
config_dict = {"client_type": "PersistentClient", "data_dir": "/tmp/nonexistent"}
config_obj = ChromaClientConfig(**config_dict)
with pytest.raises(McpError) as exc_info:
# Call the function (which is now mocked)
# We don't need to call the *real* get_chroma_client here
# Instead, we might test code *calling* get_chroma_client
# For this test, let's just call the mock to trigger the side effect
mock_get_client(config_obj)
assert "Failed to initialize ChromaDB client" in str(exc_info.value)
assert "DB connection failed" in str(exc_info.value)
mock_get_client.assert_called_once_with(config_obj)
def test_get_http_client_init_fails(self, mocker):
"""Test handling when underlying HttpClient initialization fails.
This now patches get_chroma_client itself.
"""
expected_error = McpError(
ErrorData(code=INTERNAL_ERROR, message="Failed to initialize ChromaDB client: HTTP connection refused")
)
mock_get_client = mocker.patch("chroma_mcp.utils.chroma_client.get_chroma_client", side_effect=expected_error)
config_dict = {"client_type": "HttpClient", "host": "localhost", "port": 8001, "ssl": False}
config_obj = ChromaClientConfig(**config_dict)
with pytest.raises(McpError) as exc_info:
# Call the mock to trigger the side effect
mock_get_client(config_obj)
assert "Failed to initialize ChromaDB client" in str(exc_info.value)
assert "HTTP connection refused" in str(exc_info.value)
mock_get_client.assert_called_once_with(config_obj)
def test_reset_client_no_existing_client(self, mock_logger):
"""Test reset_client when no client is cached."""
# Ensure the global client is None initially
client_module._chroma_client = None
reset_client()
# Assert logger indicates no client was reset
mock_logger.info.assert_any_call("No active Chroma client instance to reset.")
def test_reset_client_underlying_error(self, mock_logger):
"""Test reset_client handling errors from the client's reset method."""
mock_client_instance = MagicMock()
error_message = "Underlying reset failed"
mock_client_instance.reset.side_effect = Exception(error_message)
# Set the global client to the mock
client_module._chroma_client = mock_client_instance
# Call reset_client - it should catch the error and log it
reset_client()
# Assert the client's reset was called
mock_client_instance.reset.assert_called_once()
# Assert the error was logged
mock_logger.error.assert_called_once_with(f"Error resetting client: {error_message}")
# Assert the global client was still set to None
assert client_module._chroma_client is None
def test_reset_client_allow_reset_false(self, mock_logger):
"""Test reset_client handling the specific 'Resetting is not allowed' error."""
mock_client_instance = MagicMock()
error_message = "Resetting is not allowed for this client type."
mock_client_instance.reset.side_effect = Exception(error_message)
# Set the global client to the mock
client_module._chroma_client = mock_client_instance
# Call reset_client
reset_client()
# Assert the client's reset was called
mock_client_instance.reset.assert_called_once()
# Assert the specific warning was logged
mock_logger.warning.assert_called_once_with(
f"Client reset failed gracefully (allow_reset=False): {error_message}"
)
# Assert the error logger was NOT called
mock_logger.error.assert_not_called()
# Assert the global client was still set to None
assert client_module._chroma_client is None
# Error Handling Tests
class TestErrorHandling:
"""Test cases for error handling utilities."""
def test_handle_chroma_error_collection_not_found(self):
"""Test handling collection not found error."""
# This test is no longer relevant as handle_chroma_error was removed
# error = Exception("Collection not found")
# result = handle_chroma_error(error, "test_operation")
# assert isinstance(result, McpError)
# assert "Collection not found" in result.error.message
pass # Remove or adapt if similar logic exists elsewhere
def test_handle_chroma_error_client_error(self):
"""Test handling client error."""
# This test is no longer relevant as handle_chroma_error was removed
# error = Exception("connection failed")
# result = handle_chroma_error(error, "test_operation")
# assert isinstance(result, McpError)
# assert "connection failed" in result.error.message
pass # Remove or adapt
def test_raise_validation_error(self):
"""Test raising validation error."""
# This test is no longer relevant as raise_validation_error was removed
pass # Remove or adapt
# Configuration Tests
class TestConfiguration:
"""Test cases for configuration utilities."""
def test_load_config_defaults(self):
"""Test loading default configuration."""
config = load_config()
assert isinstance(config, ServerConfig)
assert config.log_level == "INFO"
assert config.max_batch_size == 100
assert config.enable_telemetry is False
def test_load_config_from_env(self, monkeypatch):
"""Test loading configuration from environment variables."""
monkeypatch.setenv("CHROMA_LOG_LEVEL", "DEBUG")
monkeypatch.setenv("CHROMA_MAX_BATCH_SIZE", "200")
monkeypatch.setenv("CHROMA_ENABLE_TELEMETRY", "true")
config = load_config()
assert config.log_level == "DEBUG"
assert config.max_batch_size == 200
assert config.enable_telemetry is True
def test_get_collection_settings_default(self):
"""Test getting default collection settings."""
settings = get_collection_settings()
assert isinstance(settings, dict)
assert "hnsw:space" in settings
def test_get_collection_settings_custom(self):
"""Test getting custom collection settings."""
settings = get_collection_settings(collection_name="test", hnsw_space="cosine", hnsw_construction_ef=100)
assert settings["hnsw:space"] == "cosine"
assert settings["hnsw:construction_ef"] == 100
def test_validate_collection_name_success(self):
"""Test successful collection name validation."""
validate_collection_name("valid_collection-123")
# Should not raise any exception
def test_validate_collection_name_empty(self):
"""Test validation of empty collection name."""
with pytest.raises(McpError) as exc_info:
validate_collection_name("")
assert "Collection name cannot be empty" in str(exc_info.value)
def test_validate_collection_name_too_long(self):
"""Test validation of too long collection name."""
with pytest.raises(McpError) as exc_info:
validate_collection_name("a" * 65)
assert "Collection name cannot be longer than 64 characters" in str(exc_info.value)
def test_validate_collection_name_invalid_chars(self):
"""Test validation of collection name with invalid characters."""
with pytest.raises(McpError) as exc_info:
validate_collection_name("invalid@collection")
assert "Collection name can only contain letters, numbers, underscores, and hyphens" in str(exc_info.value)