Skip to main content
Glama
pydantic

mcp-run-python

Official
by pydantic
test_embeddings.py (23.7 kB)
"""Tests for the `pydantic_ai.embeddings` API.

Covers the provider-specific embedding models (OpenAI, Cohere, Sentence Transformers),
model-name inference, the `Embedder` convenience/sync helpers, settings merging,
`EmbeddingResult` indexing, and Logfire/OpenTelemetry instrumentation.
"""

from __future__ import annotations

import os
from collections.abc import Iterator
from decimal import Decimal
from typing import Any, get_args
from unittest.mock import patch

import pytest
from inline_snapshot import snapshot

from pydantic_ai.embeddings import (
    Embedder,
    EmbeddingResult,
    EmbeddingSettings,
    InstrumentedEmbeddingModel,
    KnownEmbeddingModelName,
    TestEmbeddingModel,
    infer_embedding_model,
)
from pydantic_ai.exceptions import ModelHTTPError, UserError
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.usage import RequestUsage

from .conftest import IsDatetime, IsFloat, IsInt, IsList, IsStr, try_import

pytestmark = [
    pytest.mark.anyio,
]

# Optional dependencies: each `try_import()` context records whether its imports
# succeeded so the corresponding test classes can be skipped when missing.
with try_import() as logfire_imports_successful:
    from logfire.testing import CaptureLogfire

with try_import() as openai_imports_successful:
    from pydantic_ai.embeddings.openai import LatestOpenAIEmbeddingModelNames, OpenAIEmbeddingModel
    from pydantic_ai.providers.gateway import GATEWAY_BASE_URL
    from pydantic_ai.providers.openai import OpenAIProvider

with try_import() as cohere_imports_successful:
    from pydantic_ai.embeddings.cohere import CohereEmbeddingModel, LatestCohereEmbeddingModelNames
    from pydantic_ai.providers.cohere import CohereProvider

with try_import() as sentence_transformers_imports_successful:
    from sentence_transformers import SentenceTransformer

    from pydantic_ai.embeddings.sentence_transformers import SentenceTransformerEmbeddingModel


@pytest.mark.skipif(not openai_imports_successful(), reason='OpenAI not installed')
@pytest.mark.vcr
class TestOpenAI:
    """Tests for `OpenAIEmbeddingModel` (HTTP traffic handled via the `vcr` marker)."""

    @pytest.fixture
    def embedder(self, openai_api_key: str) -> Embedder:
        """Embedder backed by OpenAI's `text-embedding-3-small` model."""
        return Embedder(OpenAIEmbeddingModel('text-embedding-3-small', provider=OpenAIProvider(api_key=openai_api_key)))

    async def test_infer_model(self, openai_api_key: str):
        """`openai:<name>` resolves to an OpenAI model using the OpenAI base URL."""
        with patch.dict(os.environ, {'OPENAI_API_KEY': openai_api_key}):
            model = infer_embedding_model('openai:text-embedding-3-small')
            assert isinstance(model, OpenAIEmbeddingModel)
            assert model.model_name == 'text-embedding-3-small'
            assert model.system == 'openai'
            assert model.base_url == 'https://api.openai.com/v1/'

    async def test_infer_model_azure(self):
        """`azure:<name>` resolves to an OpenAI-compatible model without token counting support."""
        with patch.dict(
            os.environ,
            {
                'AZURE_OPENAI_API_KEY': 'azure-openai-api-key',
                'AZURE_OPENAI_ENDPOINT': 'https://project-id.openai.azure.com/',
                'OPENAI_API_VERSION': '2023-03-15-preview',
            },
        ):
            model = infer_embedding_model('azure:text-embedding-3-small')
            assert isinstance(model, OpenAIEmbeddingModel)
            assert model.model_name == 'text-embedding-3-small'
            assert model.system == 'azure'
            assert 'azure.com' in model.base_url
            assert await model.max_input_tokens() is None
            with pytest.raises(UserError, match='Counting tokens is not supported for non-OpenAI embedding models'):
                await model.count_tokens('Hello, world!')

    async def test_infer_model_gateway(self):
        """`gateway/openai:<name>` resolves to an OpenAI model routed through the Pydantic gateway."""
        with patch.dict(
            os.environ,
            {'PYDANTIC_AI_GATEWAY_API_KEY': 'test-api-key', 'PYDANTIC_AI_GATEWAY_BASE_URL': GATEWAY_BASE_URL},
        ):
            model = infer_embedding_model('gateway/openai:text-embedding-3-small')
            assert isinstance(model, OpenAIEmbeddingModel)
            assert model.model_name == 'text-embedding-3-small'
            assert model.system == 'openai'
            assert 'gateway.pydantic.dev' in model.base_url

    async def test_query(self, embedder: Embedder):
        """Embedding a single query returns one 1536-dim vector plus usage and cost."""
        result = await embedder.embed_query('Hello, world!')
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(IsList(IsFloat(), length=1536), length=1),
                inputs=['Hello, world!'],
                input_type='query',
                usage=RequestUsage(input_tokens=4),
                model_name='text-embedding-3-small',
                timestamp=IsDatetime(),
                provider_name='openai',
            )
        )
        assert result.cost().total_price == snapshot(Decimal('8E-8'))

    async def test_documents(self, embedder: Embedder):
        """Embedding several documents returns one vector per input."""
        result = await embedder.embed_documents(['hello', 'world'])
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(IsList(IsFloat(), length=1536), length=2),
                inputs=['hello', 'world'],
                input_type='document',
                usage=RequestUsage(input_tokens=2),
                model_name='text-embedding-3-small',
                timestamp=IsDatetime(),
                provider_name='openai',
            )
        )
        assert result.cost().total_price == snapshot(Decimal('4E-8'))

    async def test_max_input_tokens(self, embedder: Embedder):
        max_input_tokens = await embedder.max_input_tokens()
        assert max_input_tokens == snapshot(8192)

    async def test_count_tokens(self, embedder: Embedder):
        count = await embedder.count_tokens('Hello, world!')
        assert count == snapshot(4)

    async def test_embed_error(self, openai_api_key: str):
        """An unknown model name surfaces as a `ModelHTTPError`."""
        model = OpenAIEmbeddingModel('nonexistent', provider=OpenAIProvider(api_key=openai_api_key))
        embedder = Embedder(model)
        with pytest.raises(ModelHTTPError, match='model_not_found'):
            await embedder.embed_query('Hello, world!')

    @pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
    async def test_instrumentation(self, openai_api_key: str, capfire: CaptureLogfire):
        """With `instrument=True`, an embeddings span and token/cost metrics are exported."""
        model = OpenAIEmbeddingModel('text-embedding-3-small', provider=OpenAIProvider(api_key=openai_api_key))
        embedder = Embedder(model, instrument=True)
        await embedder.embed_query('Hello, world!', settings={'dimensions': 128})

        spans = capfire.exporter.exported_spans_as_dict(parse_json_attributes=True)
        # Other spans (e.g. HTTP client spans) may also be exported; pick the embeddings one.
        span = next(span for span in spans if 'embeddings' in span['name'])

        assert span == snapshot(
            {
                'name': 'embeddings text-embedding-3-small',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': IsInt(),
                'end_time': IsInt(),
                'attributes': {
                    'gen_ai.operation.name': 'embeddings',
                    'gen_ai.provider.name': 'openai',
                    'gen_ai.request.model': 'text-embedding-3-small',
                    'input_type': 'query',
                    'server.address': 'api.openai.com',
                    'inputs_count': 1,
                    'embedding_settings': {'dimensions': 128},
                    'inputs': ['Hello, world!'],
                    'logfire.json_schema': {
                        'type': 'object',
                        'properties': {
                            'input_type': {'type': 'string'},
                            'inputs_count': {'type': 'integer'},
                            'embedding_settings': {'type': 'object'},
                            'inputs': {'type': ['array']},
                            'embeddings': {'type': 'array'},
                        },
                    },
                    'logfire.span_type': 'span',
                    'logfire.msg': 'embeddings text-embedding-3-small',
                    'gen_ai.usage.input_tokens': 4,
                    'operation.cost': 8e-08,
                    'gen_ai.response.model': 'text-embedding-3-small',
                    'gen_ai.embeddings.dimension.count': 128,
                },
            }
        )

        assert capfire.get_collected_metrics() == snapshot(
            [
                {
                    'name': 'gen_ai.client.token.usage',
                    'description': 'Measures number of input and output tokens used',
                    'unit': '{token}',
                    'data': {
                        'data_points': [
                            {
                                'attributes': {
                                    'gen_ai.provider.name': 'openai',
                                    'gen_ai.operation.name': 'embeddings',
                                    'gen_ai.request.model': 'text-embedding-3-small',
                                    'gen_ai.response.model': 'text-embedding-3-small',
                                    'gen_ai.token.type': 'input',
                                },
                                'start_time_unix_nano': IsInt(),
                                'time_unix_nano': IsInt(),
                                'count': 1,
                                'sum': 4,
                                'scale': 20,
                                'zero_count': 0,
                                'positive': {'offset': 2097151, 'bucket_counts': [1]},
                                'negative': {'offset': 0, 'bucket_counts': [0]},
                                'flags': 0,
                                'min': 4,
                                'max': 4,
                                'exemplars': [],
                            }
                        ],
                        'aggregation_temporality': 1,
                    },
                },
                {
                    'name': 'operation.cost',
                    'description': 'Monetary cost',
                    'unit': '{USD}',
                    'data': {
                        'data_points': [
                            {
                                'attributes': {
                                    'gen_ai.provider.name': 'openai',
                                    'gen_ai.operation.name': 'embeddings',
                                    'gen_ai.request.model': 'text-embedding-3-small',
                                    'gen_ai.response.model': 'text-embedding-3-small',
                                    'gen_ai.token.type': 'input',
                                },
                                'start_time_unix_nano': IsInt(),
                                'time_unix_nano': IsInt(),
                                'count': 1,
                                'sum': 8e-08,
                                'scale': 20,
                                'zero_count': 0,
                                'positive': {'offset': -24720625, 'bucket_counts': [1]},
                                'negative': {'offset': 0, 'bucket_counts': [0]},
                                'flags': 0,
                                'min': 8e-08,
                                'max': 8e-08,
                                'exemplars': [],
                            }
                        ],
                        'aggregation_temporality': 1,
                    },
                },
            ]
        )


@pytest.mark.skipif(not cohere_imports_successful(), reason='Cohere not installed')
@pytest.mark.vcr
class TestCohere:
    """Tests for `CohereEmbeddingModel` (HTTP traffic handled via the `vcr` marker)."""

    @pytest.fixture
    def embedder(self, co_api_key: str) -> Embedder:
        """Embedder backed by Cohere's `embed-v4.0` model (mirrors `TestOpenAI.embedder`)."""
        return Embedder(CohereEmbeddingModel('embed-v4.0', provider=CohereProvider(api_key=co_api_key)))

    async def test_infer_model(self, co_api_key: str):
        """`cohere:<name>` resolves to a Cohere model using the Cohere base URL."""
        with patch.dict(os.environ, {'CO_API_KEY': co_api_key}):
            model = infer_embedding_model('cohere:embed-v4.0')
            assert isinstance(model, CohereEmbeddingModel)
            assert model.model_name == 'embed-v4.0'
            assert model.system == 'cohere'
            assert model.base_url == 'https://api.cohere.com'
            assert isinstance(model._provider, CohereProvider)  # type: ignore[reportAttributeAccess]

    async def test_query(self, embedder: Embedder):
        """Embedding a single query returns one 1536-dim vector plus usage, cost and response id."""
        result = await embedder.embed_query('Hello, world!')
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(
                    IsList(snapshot(-0.018445116), snapshot(0.008921167), snapshot(-0.0011377502), length=1536),
                    length=1,
                ),
                inputs=['Hello, world!'],
                input_type='query',
                usage=RequestUsage(input_tokens=4),
                model_name='embed-v4.0',
                timestamp=IsDatetime(),
                provider_name='cohere',
                provider_response_id='0728b136-9b30-4fb5-bf9a-2c7cf36d51d3',
            )
        )
        assert result.cost().total_price == snapshot(Decimal('4.8E-7'))

    async def test_documents(self, embedder: Embedder):
        """Embedding several documents returns one vector per input."""
        result = await embedder.embed_documents(['hello', 'world'])
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(IsList(IsFloat(), length=1536), length=2),
                inputs=['hello', 'world'],
                input_type='document',
                usage=RequestUsage(input_tokens=2),
                model_name='embed-v4.0',
                timestamp=IsDatetime(),
                provider_name='cohere',
                provider_response_id='199299d7-f43d-45af-903c-347fff81bbe4',
            )
        )
        assert result.cost().total_price == snapshot(Decimal('2.4E-7'))

    async def test_max_input_tokens(self, embedder: Embedder):
        max_input_tokens = await embedder.max_input_tokens()
        assert max_input_tokens == snapshot(128000)

    async def test_count_tokens(self, embedder: Embedder):
        count = await embedder.count_tokens('Hello, world!')
        assert count == snapshot(4)

    async def test_embed_error(self, co_api_key: str):
        """An unknown model name surfaces as a `ModelHTTPError`."""
        model = CohereEmbeddingModel('nonexistent', provider=CohereProvider(api_key=co_api_key))
        embedder = Embedder(model)
        with pytest.raises(ModelHTTPError, match='not found,'):
            await embedder.embed_query('Hello, world!')


@pytest.mark.skipif(not sentence_transformers_imports_successful(), reason='SentenceTransformers not installed')
class TestSentenceTransformers:
    """Tests for the local `SentenceTransformerEmbeddingModel`."""

    @pytest.fixture(scope='session')
    def stsb_bert_tiny_model(self):
        """A tiny SentenceTransformer model, loaded once per session."""
        model = SentenceTransformer('sentence-transformers-testing/stsb-bert-tiny-safetensors')
        model.model_card_data.generate_widget_examples = False  # Disable widget examples generation for testing
        return model

    @pytest.fixture
    def embedder(self, stsb_bert_tiny_model: Any) -> Embedder:
        return Embedder(SentenceTransformerEmbeddingModel(stsb_bert_tiny_model))

    async def test_infer_model(self):
        """`sentence-transformers:<name>` resolves to a local model with no base URL."""
        model = infer_embedding_model('sentence-transformers:all-MiniLM-L6-v2')
        assert isinstance(model, SentenceTransformerEmbeddingModel)
        assert model.model_name == 'all-MiniLM-L6-v2'
        assert model.system == 'sentence-transformers'
        assert model.base_url is None

    async def test_query(self, embedder: Embedder):
        result = await embedder.embed_query('Hello, world!')
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(IsList(IsFloat(), length=128), length=1),
                inputs=['Hello, world!'],
                input_type='query',
                model_name='sentence-transformers-testing/stsb-bert-tiny-safetensors',
                timestamp=IsDatetime(),
                provider_name='sentence-transformers',
            )
        )

    async def test_documents(self, embedder: Embedder):
        result = await embedder.embed_documents(['hello', 'world'])
        assert result == snapshot(
            EmbeddingResult(
                embeddings=IsList(IsList(IsFloat(), length=128), length=2),
                inputs=['hello', 'world'],
                input_type='document',
                model_name='sentence-transformers-testing/stsb-bert-tiny-safetensors',
                timestamp=IsDatetime(),
                provider_name='sentence-transformers',
            )
        )

    async def test_max_input_tokens(self, embedder: Embedder):
        max_input_tokens = await embedder.max_input_tokens()
        assert max_input_tokens == snapshot(512)

    async def test_count_tokens(self, embedder: Embedder):
        count = await embedder.count_tokens('Hello, world!')
        assert count == snapshot(6)


@pytest.mark.skipif(
    not openai_imports_successful() or not cohere_imports_successful(),
    reason='some embedding package was not installed',
)
def test_known_embedding_model_names():  # pragma: lax no cover
    """`KnownEmbeddingModelName` must list exactly the provider-prefixed latest model names."""

    # Coverage seems to be misbehaving..?
    def get_model_names(model_name_type: Any) -> Iterator[str]:
        # Recursively flatten nested `Literal`/`Union` type arguments into plain strings.
        for arg in get_args(model_name_type):
            if isinstance(arg, str):
                yield arg
            else:
                yield from get_model_names(arg)

    openai_names = [f'openai:{n}' for n in get_model_names(LatestOpenAIEmbeddingModelNames)]
    cohere_names = [f'cohere:{n}' for n in get_model_names(LatestCohereEmbeddingModelNames)]

    generated_names = sorted(openai_names + cohere_names)

    known_model_names = sorted(get_args(KnownEmbeddingModelName.__value__))
    if generated_names != known_model_names:
        errors: list[str] = []
        missing_names = set(generated_names) - set(known_model_names)
        if missing_names:
            errors.append(f'Missing names: {missing_names}')
        extra_names = set(known_model_names) - set(generated_names)
        if extra_names:
            errors.append(f'Extra names: {extra_names}')
        # If the sets match but the sorted lists don't, the difference is duplicate entries;
        # never raise an AssertionError with an empty message.
        raise AssertionError(
            '\n'.join(errors)
            or f'Name lists differ (possibly duplicate entries): {generated_names} != {known_model_names}'
        )


def test_infer_model_error():
    """A model name without a provider prefix is rejected."""
    with pytest.raises(ValueError, match='You must provide a provider prefix when specifying an embedding model name'):
        infer_embedding_model('nonexistent')


async def test_instrument_all():
    """`Embedder.instrument_all` toggles wrapping of the model in `InstrumentedEmbeddingModel`."""
    model = TestEmbeddingModel()
    embedder = Embedder(model)

    def get_model():
        return embedder._get_model()  # pyright: ignore[reportPrivateUsage]

    Embedder.instrument_all(False)
    assert get_model() is model

    # `instrument_all` mutates global state shared by every Embedder; always reset it so a
    # failing assertion below cannot leak instrumentation into subsequent tests.
    try:
        Embedder.instrument_all()
        m = get_model()
        assert isinstance(m, InstrumentedEmbeddingModel)
        assert m.wrapped is model
        assert m.instrumentation_settings.event_mode == InstrumentationSettings().event_mode
        assert m.model_name == model.model_name
        assert m.system == model.system
        assert m.base_url == model.base_url
        assert m.settings == model.settings
        assert (await m.embed('Hello, world!', input_type='query')).embeddings == (
            await model.embed('Hello, world!', input_type='query')
        ).embeddings
        assert await m.max_input_tokens() == await model.max_input_tokens()
        assert await m.count_tokens('Hello, world!') == await model.count_tokens('Hello, world!')

        options = InstrumentationSettings(version=1, event_mode='logs')
        Embedder.instrument_all(options)
        m = get_model()
        assert isinstance(m, InstrumentedEmbeddingModel)
        assert m.wrapped is model
        assert m.instrumentation_settings is options
    finally:
        Embedder.instrument_all(False)

    assert get_model() is model


def test_override():
    """`Embedder.override` swaps the model only for the duration of the context."""
    model = TestEmbeddingModel()
    embedder = Embedder(model)

    model2 = TestEmbeddingModel()
    with embedder.override(model=model2):
        assert embedder._get_model() is model2  # pyright: ignore[reportPrivateUsage]

    with embedder.override():
        assert embedder._get_model() is model  # pyright: ignore[reportPrivateUsage]

    assert embedder._get_model() is model  # pyright: ignore[reportPrivateUsage]


def test_sync():
    """The `*_sync` helpers return the same result types as their async counterparts."""
    model = TestEmbeddingModel()
    embedder = Embedder(model)

    result = embedder.embed_query_sync('Hello, world!')
    assert isinstance(result, EmbeddingResult)

    result = embedder.embed_documents_sync(['hello', 'world'])
    assert isinstance(result, EmbeddingResult)

    result = embedder.embed_sync('Hello, world!', input_type='query')
    assert isinstance(result, EmbeddingResult)

    result = embedder.max_input_tokens_sync()
    assert isinstance(result, int)

    result = embedder.count_tokens_sync('Hello, world!')
    assert isinstance(result, int)


async def test_settings():
    """Settings merge model -> embedder -> per-call, with later layers overriding earlier ones."""
    model_settings: EmbeddingSettings = {'dimensions': 128, 'from_model': True}  # pyright: ignore[reportAssignmentType]
    model = TestEmbeddingModel(settings=model_settings)
    assert model.settings == model_settings
    await Embedder(model).embed_query('Hello, world!')
    assert model.last_settings == snapshot({'dimensions': 128, 'from_model': True})

    embedder_settings: EmbeddingSettings = {'dimensions': 256, 'from_embedder': True}  # pyright: ignore[reportAssignmentType]
    embedder = Embedder(model, settings=embedder_settings)
    await embedder.embed_query('Hello, world!')
    assert model.last_settings == snapshot({'dimensions': 256, 'from_model': True, 'from_embedder': True})

    embed_settings: EmbeddingSettings = {'dimensions': 512, 'from_embed': True}  # pyright: ignore[reportAssignmentType]
    await embedder.embed_query('Hello, world!', settings=embed_settings)
    assert model.last_settings == snapshot(
        {'dimensions': 512, 'from_model': True, 'from_embedder': True, 'from_embed': True}
    )


def test_result():
    """`EmbeddingResult` supports lookup by positional index and by input text."""
    result = EmbeddingResult(
        embeddings=[[-1.0], [-0.5], [0.0], [0.5], [1.0]],
        inputs=['a', 'b', 'c', 'd', 'e'],
        input_type='document',
        model_name='test',
        timestamp=IsDatetime(),
        provider_name='test',
    )
    assert result[0] == result['a'] == snapshot([-1.0])
    assert result[1] == result['b'] == snapshot([-0.5])
    assert result[2] == result['c'] == snapshot([0.0])
    assert result[3] == result['d'] == snapshot([0.5])
    assert result[4] == result['e'] == snapshot([1.0])


@pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
async def test_limited_instrumentation(capfire: CaptureLogfire):
    """With `include_content=False`, the span omits the `inputs` and `embeddings` content."""
    model = TestEmbeddingModel()
    embedder = Embedder(model, instrument=InstrumentationSettings(include_content=False))
    await embedder.embed_query('Hello, world!')

    assert capfire.exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot(
        [
            {
                'name': 'embeddings test',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': IsInt(),
                'end_time': IsInt(),
                'attributes': {
                    'gen_ai.operation.name': 'embeddings',
                    'gen_ai.provider.name': 'test',
                    'gen_ai.request.model': 'test',
                    'input_type': 'query',
                    'inputs_count': 1,
                    'logfire.json_schema': {
                        'type': 'object',
                        'properties': {
                            'input_type': {'type': 'string'},
                            'inputs_count': {'type': 'integer'},
                            'embedding_settings': {'type': 'object'},
                        },
                    },
                    'logfire.span_type': 'span',
                    'logfire.msg': 'embeddings test',
                    'gen_ai.usage.input_tokens': 2,
                    'gen_ai.response.model': 'test',
                    'gen_ai.embeddings.dimension.count': 8,
                    'gen_ai.response.id': IsStr(),
                },
            }
        ]
    )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.