# test_gemini.py • 15.1 kB
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/embedder/test_gemini.py
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from embedder_fixtures import create_embedding_values
from graphiti_core.embedder.gemini import (
DEFAULT_EMBEDDING_MODEL,
GeminiEmbedder,
GeminiEmbedderConfig,
)
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
    """Build a stand-in for a single Gemini embedding object.

    The returned mock exposes a ``values`` attribute holding whatever
    ``create_embedding_values(multiplier, dimension)`` produces.
    """
    return MagicMock(values=create_embedding_values(multiplier, dimension))
@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Provide a mocked embeddings response carrying exactly one default embedding."""
    response = MagicMock()
    single_embedding = create_gemini_embedding()
    response.embeddings = [single_embedding]
    return response
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Provide a mocked batch response carrying three embeddings.

    The embeddings are built with multipliers 0.1, 0.2 and 0.3, in that order.
    """
    response = MagicMock()
    response.embeddings = [create_gemini_embedding(scale) for scale in (0.1, 0.2, 0.3)]
    return response
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Yield a mocked Gemini client instance while ``google.genai.Client`` is patched.

    The yielded object is the patched class's ``return_value``; its
    ``aio.models.embed_content`` attribute is an ``AsyncMock`` so tests can
    await it and set its ``return_value`` / ``side_effect``.
    """
    with patch('google.genai.Client') as patched_client_cls:
        client_instance = patched_client_cls.return_value

        # Assemble the aio -> models -> embed_content chain from the leaf upwards.
        models_api = MagicMock()
        models_api.embed_content = AsyncMock()
        aio_api = MagicMock()
        aio_api.models = models_api
        client_instance.aio = aio_api

        yield client_instance
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Provide a GeminiEmbedder whose ``client`` attribute is the mocked Gemini client."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    # Swap in the mock explicitly so tests control every embed_content call.
    embedder.client = mock_gemini_client
    return embedder
class TestGeminiEmbedderInitialization:
    """Construction-time checks for GeminiEmbedder.

    Each test runs with ``google.genai.Client`` replaced by a mock.
    """

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """A fully specified config is kept on the embedder with all fields intact."""
        custom_config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
        )

        embedder = GeminiEmbedder(config=custom_config)

        stored_config = embedder.config
        assert stored_config == custom_config
        assert stored_config.embedding_model == 'custom-model'
        assert stored_config.api_key == 'test_api_key'
        assert stored_config.embedding_dim == 768

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config still yields a config that uses the default model."""
        embedder = GeminiEmbedder()

        stored_config = embedder.config
        assert stored_config is not None
        assert stored_config.embedding_model == DEFAULT_EMBEDDING_MODEL

    @patch('google.genai.Client')
    def test_init_with_partial_config(self, mock_client):
        """A config that only sets the API key falls back to the default model."""
        embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))

        stored_config = embedder.config
        assert stored_config.api_key == 'test_api_key'
        assert stored_config.embedding_model == DEFAULT_EMBEDDING_MODEL
class TestGeminiEmbedderCreate:
    """Checks for ``GeminiEmbedder.create`` against a mocked Gemini client."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() makes one embed_content call and returns the first embedding's values."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        embedding = await gemini_embedder.create('Test input')

        # One request, default model, input wrapped in a single-element list.
        embed_content.assert_called_once()
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == ['Test input']

        # The returned vector is the values of the first embedding in the response.
        assert embedding == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """The model name from the config is forwarded to embed_content."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        )
        embedder.client = mock_gemini_client
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        await embedder.create('Test input')

        assert embed_content.call_args.kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """The configured embedding_dim is sent as output_dimensionality."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        )
        embedder.client = mock_gemini_client

        # Response whose single embedding is built with the custom dimension.
        response_768 = MagicMock()
        response_768.embeddings = [create_gemini_embedding(0.1, 768)]
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = response_768

        embedding = await embedder.create('Test input')

        # The request config carries the custom dimension ...
        assert embed_content.call_args.kwargs['config'].output_dimensionality == 768
        # ... and the returned vector has that length.
        assert len(embedding) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() accepts a string, a list of strings and a list of integers."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        # In order: plain string, list of strings, iterable of integers.
        for payload in ('Test string', ['Test', 'List'], [1, 2, 3]):
            await gemini_embedder.create(payload)

        # Each input produced exactly one API call.
        assert embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A response with an empty embeddings list makes create() raise ValueError."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        raised_message = str(exc_info.value)
        assert 'No embeddings returned from Gemini API in create()' in raised_message

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding whose values are None makes create() raise ValueError."""
        # The embedding object is present, but carries no vector.
        valueless_embedding = MagicMock()
        valueless_embedding.values = None
        response = MagicMock()
        response.embeddings = [valueless_embedding]
        mock_gemini_client.aio.models.embed_content.return_value = response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        raised_message = str(exc_info.value)
        assert 'No embeddings returned from Gemini API in create()' in raised_message
class TestGeminiEmbedderCreateBatch:
    """Checks for ``GeminiEmbedder.create_batch`` against a mocked Gemini client."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """Three inputs go out in one call and come back as three vectors, in order."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_batch_response
        input_batch = ['Input 1', 'Input 2', 'Input 3']

        vectors = await gemini_embedder.create_batch(input_batch)

        # A single request with the default model and the batch passed through as-is.
        embed_content.assert_called_once()
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == input_batch

        # One vector per mocked embedding, preserving response order.
        mocked_embeddings = mock_gemini_batch_response.embeddings
        assert len(vectors) == 3
        assert vectors == [
            mocked_embeddings[0].values,
            mocked_embeddings[1].values,
            mocked_embeddings[2].values,
        ]

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A one-element batch yields a one-element result."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        vectors = await gemini_embedder.create_batch(['Single input'])

        assert len(vectors) == 1
        assert vectors[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty batch returns an empty list without calling the API."""
        embed_content = mock_gemini_client.aio.models.embed_content
        # A response is configured anyway; the test asserts it is never requested.
        empty_response = MagicMock()
        empty_response.embeddings = []
        embed_content.return_value = empty_response

        vectors = await gemini_embedder.create_batch([])

        assert vectors == []
        embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A response without embeddings for a non-empty batch raises ValueError."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding with no values anywhere in the batch raises ValueError."""
        populated_embedding = MagicMock()
        populated_embedding.values = [0.1, 0.2, 0.3]  # Valid values
        valueless_embedding = MagicMock()
        valueless_embedding.values = None  # Empty values

        def response_with(*embeddings: MagicMock) -> MagicMock:
            """Wrap the given embeddings in a mocked API response."""
            response = MagicMock()
            response.embeddings = list(embeddings)
            return response

        # Responses are consumed in call order: the batch request first, then the
        # per-item requests this test expects for 'Input 1' and 'Input 2'.
        mock_gemini_client.aio.models.embed_content.side_effect = [
            response_with(populated_embedding, valueless_embedding),
            response_with(populated_embedding),
            response_with(valueless_embedding),
        ]

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Custom model and dimension from the config reach the batch request."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(
                api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
            )
        )
        embedder.client = mock_gemini_client

        # Two mocked embeddings built with the custom dimension.
        response_512 = MagicMock()
        response_512.embeddings = [create_gemini_embedding(scale, 512) for scale in (0.1, 0.2)]
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = response_512

        vectors = await embedder.create_batch(['Input 1', 'Input 2'])

        # Request used the configured model and dimensionality.
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == 'custom-batch-model'
        assert call_kwargs['config'].output_dimensionality == 512

        # Both returned vectors have the custom length.
        assert len(vectors) == 2
        assert all(len(vector) == 512 for vector in vectors)
if __name__ == '__main__':
    # Allow running this module directly. pytest.main() returns the session's
    # exit code rather than exiting, so raise SystemExit with it; otherwise the
    # process would exit 0 even when tests fail.
    raise SystemExit(pytest.main(['-xvs', __file__]))