from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Any, cast
import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel, TypeAdapter
from pydantic_core import to_jsonable_python
from pydantic_ai import ModelMessage, ModelResponse
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from ..conftest import IsStr, try_import
with try_import() as imports_successful:
    import logfire
    from logfire.testing import CaptureLogfire
    from pydantic_evals.evaluators._run_evaluator import run_evaluator
    from pydantic_evals.evaluators.common import (
        Contains,
        Equals,
        EqualsExpected,
        HasMatchingSpan,
        IsInstance,
        LLMJudge,
        MaxDuration,
    )
    from pydantic_evals.evaluators.context import EvaluatorContext
    from pydantic_evals.evaluators.evaluator import (
        EvaluationReason,
        EvaluationResult,
        Evaluator,
        EvaluatorFailure,
        EvaluatorOutput,
    )
    from pydantic_evals.evaluators.spec import EvaluatorSpec
    from pydantic_evals.otel._context_in_memory_span_exporter import context_subtree
    from pydantic_evals.otel.span_tree import SpanQuery, SpanTree
pytestmark = [pytest.mark.skipif(not imports_successful(), reason='pydantic-evals not installed'), pytest.mark.anyio]
class TaskInput(BaseModel):
    query: str
class TaskOutput(BaseModel):
    answer: str
class TaskMetadata(BaseModel):
    difficulty: str = 'easy'
@pytest.fixture
def test_context() -> EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]:
    return EvaluatorContext[TaskInput, TaskOutput, TaskMetadata](
        name='test_case',
        inputs=TaskInput(query='What is 2+2?'),
        output=TaskOutput(answer='4'),
        expected_output=TaskOutput(answer='4'),
        metadata=TaskMetadata(difficulty='easy'),
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
async def test_evaluator_spec_initialization():
    """Test initializing EvaluatorSpec."""
    # Simple form with just a name
    spec1 = EvaluatorSpec(name='MyEvaluator', arguments=None)
    assert spec1.name == 'MyEvaluator'
    assert spec1.args == ()
    assert spec1.kwargs == {}
    # Form with args - using a tuple with a single element containing a tuple
    args_tuple = cast(tuple[Any], (('arg1', 'arg2'),))
    spec2 = EvaluatorSpec(name='MyEvaluator', arguments=args_tuple)
    assert spec2.name == 'MyEvaluator'
    assert len(spec2.args) == 1
    assert spec2.args[0] == ('arg1', 'arg2')
    assert spec2.kwargs == {}
    # Form with kwargs
    spec3 = EvaluatorSpec(name='MyEvaluator', arguments={'key1': 'value1', 'key2': 'value2'})
    assert spec3.name == 'MyEvaluator'
    assert spec3.args == ()
    assert spec3.kwargs == {'key1': 'value1', 'key2': 'value2'}
async def test_evaluator_spec_serialization():
    """Test serializing EvaluatorSpec."""
    # Create a spec
    spec = EvaluatorSpec(name='MyEvaluator', arguments={'key1': 'value1'})
    adapter = TypeAdapter(EvaluatorSpec)
    assert adapter.dump_python(spec) == snapshot({'name': 'MyEvaluator', 'arguments': {'key1': 'value1'}})
    assert adapter.dump_python(spec, context={'use_short_form': True}) == snapshot({'MyEvaluator': {'key1': 'value1'}})
    # A spec with no arguments serializes to a plain string in short form
    spec_simple = EvaluatorSpec(name='MyEvaluator', arguments=None)
    assert adapter.dump_python(spec_simple) == snapshot({'name': 'MyEvaluator', 'arguments': None})
    assert adapter.dump_python(spec_simple, context={'use_short_form': True}) == snapshot('MyEvaluator')
    # Test single arg serialization
    single_arg = cast(tuple[Any], ('value1',))
    spec_single_arg = EvaluatorSpec(name='MyEvaluator', arguments=single_arg)
    assert adapter.dump_python(spec_single_arg) == snapshot({'name': 'MyEvaluator', 'arguments': ('value1',)})
    assert adapter.dump_python(spec_single_arg, context={'use_short_form': True}) == snapshot({'MyEvaluator': 'value1'})
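    # The short forms above mirror how evaluators appear in serialized datasets:
    # 'MyEvaluator' (no arguments), {'MyEvaluator': 'value1'} (one positional argument),
    # or {'MyEvaluator': {'key1': 'value1'}} (keyword arguments).
    # A sketch, assuming validation accepts the same short forms the dumps produce:
    # assert adapter.validate_python('MyEvaluator') == EvaluatorSpec(name='MyEvaluator', arguments=None)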
async def test_llm_judge_serialization():
    # Ensure models are serialized as '<system>:<model_name>' when used with LLMJudge
    class MyModel(Model):
        async def request(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            raise NotImplementedError
        @property
        def model_name(self) -> str:
            return 'my-model'
        @property
        def system(self) -> str:
            return 'my-system'
    adapter = TypeAdapter(Evaluator)
    assert adapter.dump_python(LLMJudge(rubric='my rubric', model=MyModel())) == {
        'name': 'LLMJudge',
        'arguments': {'model': 'my-system:my-model', 'rubric': 'my rubric'},
    }
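    # The model is reduced to the '<system>:<model_name>' string, and it appears that only
    # arguments differing from their defaults are included, which keeps the spec
    # serializable without a live model instance (inferred from the assertion above).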
async def test_evaluator_call(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test calling an Evaluator."""
    @dataclass
    class ExampleEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        """A test evaluator for testing purposes."""
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            assert ctx.inputs.query == 'What is 2+2?'
            assert ctx.output.answer == '4'
            assert ctx.expected_output and ctx.expected_output.answer == '4'
            assert ctx.metadata and ctx.metadata.difficulty == 'easy'
            return {'result': 'passed'}
    evaluator = ExampleEvaluator()
    results = await run_evaluator(evaluator, test_context)
    assert not isinstance(results, EvaluatorFailure)
    assert len(results) == 1
    first_result = results[0]
    assert isinstance(first_result, EvaluationResult)
    assert first_result.name == 'result'
    assert first_result.value == 'passed'
    assert first_result.reason is None
    assert first_result.source == EvaluatorSpec(name='ExampleEvaluator', arguments=None)
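    # A dict returned from evaluate() is expanded into one EvaluationResult per key, with
    # the key used as the result name and the evaluator's spec recorded as the source.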
async def test_is_instance_evaluator():
    """Test the IsInstance evaluator."""
    # IsInstance can evaluate any output type, so use a context parametrized over object
    object_context = EvaluatorContext[object, object, object](
        name='test_case',
        inputs=TaskInput(query='What is 2+2?'),
        output=TaskOutput(answer='4'),
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    # Test with matching types
    evaluator = IsInstance(type_name='TaskOutput')
    result = evaluator.evaluate(object_context)
    assert isinstance(result, EvaluationReason)
    assert result.value is True
    # Test with non-matching types
    class DifferentOutput(BaseModel):
        different_field: str
    # Create a context with DifferentOutput
    diff_context = EvaluatorContext[object, object, object](
        name='mismatch_case',
        inputs=TaskInput(query='What is 2+2?'),
        output=DifferentOutput(different_field='not an answer'),
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    result = evaluator.evaluate(diff_context)
    assert isinstance(result, EvaluationReason)
    assert result.value is False
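    # IsInstance compares the output's type against the given type_name (here 'TaskOutput'
    # vs 'DifferentOutput') and returns an EvaluationReason rather than a bare bool.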
async def test_custom_evaluator(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test a custom evaluator."""
    @dataclass
    class CustomEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            # Check if the answer is correct based on expected output
            is_correct = ctx.output.answer == ctx.expected_output.answer if ctx.expected_output else False
            # Use metadata if available
            difficulty = ctx.metadata.difficulty if ctx.metadata else 'unknown'
            return {
                'is_correct': is_correct,
                'difficulty': difficulty,
            }
    evaluator = CustomEvaluator()
    result = evaluator.evaluate(test_context)
    assert result == snapshot({'difficulty': 'easy', 'is_correct': True})
async def test_custom_evaluator_name(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    @dataclass
    class CustomNameFieldEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        result: int
        evaluation_name: str
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            return self.result
    evaluator = CustomNameFieldEvaluator(result=123, evaluation_name='abc')
    assert to_jsonable_python(await run_evaluator(evaluator, test_context)) == snapshot(
        [
            {
                'name': 'abc',
                'reason': None,
                'source': {'arguments': {'evaluation_name': 'abc', 'result': 123}, 'name': 'CustomNameFieldEvaluator'},
                'value': 123,
            }
        ]
    )
    @dataclass
    class CustomNamePropertyEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        result: int
        my_name: str
        @property
        def evaluation_name(self) -> str:
            return f'hello {self.my_name}'
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            return self.result
    evaluator = CustomNamePropertyEvaluator(result=123, my_name='marcelo')
    assert to_jsonable_python(await run_evaluator(evaluator, test_context)) == snapshot(
        [
            {
                'name': 'hello marcelo',
                'reason': None,
                'source': {'arguments': {'my_name': 'marcelo', 'result': 123}, 'name': 'CustomNamePropertyEvaluator'},
                'value': 123,
            }
        ]
    )
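    # An 'evaluation_name' field or property on the evaluator overrides the default result
    # name (otherwise the evaluator's own name is used for scalar outputs).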
async def test_evaluator_error_handling(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test error handling in evaluators."""
    @dataclass
    class FailingEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            raise ValueError('Simulated error')
    evaluator = FailingEvaluator()
    # run_evaluator should capture the exception and return an EvaluatorFailure rather than raising
    result = await run_evaluator(evaluator, test_context)
    assert result == EvaluatorFailure(
        name='FailingEvaluator',
        error_message='ValueError: Simulated error',
        error_stacktrace=IsStr(),
        source=FailingEvaluator().as_spec(),
    )
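    # Exceptions raised inside evaluate() are captured by run_evaluator and returned as an
    # EvaluatorFailure (name, error message, stacktrace, and source spec) instead of propagating.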
async def test_evaluator_with_null_values():
    """Test evaluator with null expected_output and metadata."""
    @dataclass
    class NullValueEvaluator(Evaluator[TaskInput, TaskOutput, TaskMetadata]):
        def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]) -> EvaluatorOutput:
            return {
                'has_expected_output': ctx.expected_output is not None,
                'has_metadata': ctx.metadata is not None,
            }
    evaluator = NullValueEvaluator()
    context = EvaluatorContext[TaskInput, TaskOutput, TaskMetadata](
        name=None,
        inputs=TaskInput(query='What is 2+2?'),
        output=TaskOutput(answer='4'),
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    result = evaluator.evaluate(context)
    assert isinstance(result, dict)
    assert result['has_expected_output'] is False
    assert result['has_metadata'] is False
async def test_equals_evaluator(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test the equals evaluator."""
    # Test with matching value
    evaluator = Equals(value=TaskOutput(answer='4'))
    result = evaluator.evaluate(test_context)
    assert result is True
    # Test with non-matching value
    evaluator = Equals(value=TaskOutput(answer='5'))
    result = evaluator.evaluate(test_context)
    assert result is False
    # Test with completely different type
    evaluator = Equals(value='not a TaskOutput')
    result = evaluator.evaluate(test_context)
    assert result is False
async def test_equals_expected_evaluator(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test the equals_expected evaluator."""
    # Test with matching expected output (already set in test_context)
    evaluator = EqualsExpected()
    result = evaluator.evaluate(test_context)
    assert result is True
    # Test with non-matching expected output
    context_with_different_expected = EvaluatorContext[TaskInput, TaskOutput, TaskMetadata](
        name='test_case',
        inputs=TaskInput(query='What is 2+2?'),
        output=TaskOutput(answer='4'),
        expected_output=TaskOutput(answer='5'),  # Different expected output
        metadata=TaskMetadata(difficulty='easy'),
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    result = evaluator.evaluate(context_with_different_expected)
    assert result is False
    # Test with no expected output
    context_with_no_expected = EvaluatorContext[TaskInput, TaskOutput, TaskMetadata](
        name='test_case',
        inputs=TaskInput(query='What is 2+2?'),
        output=TaskOutput(answer='4'),
        expected_output=None,  # No expected output
        metadata=TaskMetadata(difficulty='easy'),
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    result = evaluator.evaluate(context_with_no_expected)
    assert result == {}  # Should return empty dict when no expected output
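    # An empty dict means no evaluation results are recorded for the case, so EqualsExpected
    # is effectively a no-op when there is nothing to compare against.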
async def test_contains_evaluator():
    """Test the contains evaluator."""
    # Test with string output
    string_context = EvaluatorContext[object, str, object](
        name='string_test',
        inputs="What's in the box?",
        output='There is a cat in the box',
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    # String contains - case sensitive
    evaluator = Contains(value='cat in the')
    assert evaluator.evaluate(string_context) == snapshot(EvaluationReason(value=True))
    # String doesn't contain
    evaluator = Contains(value='dog')
    assert evaluator.evaluate(string_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output string 'There is a cat in the box' does not contain expected string 'dog'",
        )
    )
    # Very long expected strings are truncated in the reason
    evaluator = Contains(value='a' * 1000)
    assert evaluator.evaluate(string_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output string 'There is a cat in the box' does not contain expected string 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'",
        )
    )
    # Case sensitivity
    evaluator = Contains(value='CAT', case_sensitive=True)
    assert evaluator.evaluate(string_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output string 'There is a cat in the box' does not contain expected string 'CAT'",
        )
    )
    evaluator = Contains(value='CAT', case_sensitive=False)
    assert evaluator.evaluate(string_context) == snapshot(EvaluationReason(value=True))
    # Test with list output
    list_context = EvaluatorContext[object, list[int], object](
        name='list_test',
        inputs='List items',
        output=[1, 2, 3, 4, 5],
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    # List contains
    evaluator = Contains(value=3)
    assert evaluator.evaluate(list_context) == snapshot(EvaluationReason(value=True))
    # List doesn't contain
    evaluator = Contains(value=6)
    assert evaluator.evaluate(list_context) == snapshot(
        EvaluationReason(value=False, reason='Output [1, 2, 3, 4, 5] does not contain provided value')
    )
    # Test with dict output
    dict_context = EvaluatorContext[object, dict[str, str], object](
        name='dict_test',
        inputs='Dict items',
        output={'key1': 'value1', 'key2': 'value2'},
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=SpanTree(),
        attributes={},
        metrics={},
    )
    # Dict contains key
    evaluator = Contains(value='key1')
    assert evaluator.evaluate(dict_context) == snapshot(EvaluationReason(value=True))
    # Dict contains subset
    evaluator = Contains(value={'key1': 'value1'})
    assert evaluator.evaluate(dict_context) == snapshot(EvaluationReason(value=True))
    # Dict doesn't contain key-value pair
    evaluator = Contains(value={'key1': 'wrong_value'})
    assert evaluator.evaluate(dict_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output dictionary has different value for key 'key1': 'value1' != 'wrong_value'",
        )
    )
    # Dict doesn't contain key
    evaluator = Contains(value='key3')
    assert evaluator.evaluate(dict_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output {'key1': 'value1', 'key2': 'value2'} does not contain provided value as a key",
        )
    )
    # Very long keys are truncated
    evaluator = Contains(value={'key1' * 500: 'wrong_value'})
    assert evaluator.evaluate(dict_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output dictionary does not contain expected key 'key1key1key1ke...y1key1key1key1'",
        )
    )
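    # Very long values are truncated in the reason as well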
    evaluator = Contains(value={'key1': 'wrong_value_' * 500})
    assert evaluator.evaluate(dict_context) == snapshot(
        EvaluationReason(
            value=False,
            reason="Output dictionary has different value for key 'key1': 'value1' != 'wrong_value_wrong_value_wrong_value_wrong_value_w..._wrong_value_wrong_value_wrong_value_wrong_value_'",
        )
    )
async def test_max_duration_evaluator(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
    """Test the max_duration evaluator."""
    from datetime import timedelta
    # Test with duration under the maximum (using float seconds)
    evaluator = MaxDuration(seconds=0.2)  # test_context has duration=0.1
    result = evaluator.evaluate(test_context)
    assert result is True
    # Test with duration over the maximum
    evaluator = MaxDuration(seconds=0.05)
    result = evaluator.evaluate(test_context)
    assert result is False
    # Test with timedelta
    evaluator = MaxDuration(seconds=timedelta(milliseconds=200))
    result = evaluator.evaluate(test_context)
    assert result is True
    evaluator = MaxDuration(seconds=timedelta(milliseconds=50))
    result = evaluator.evaluate(test_context)
    assert result is False
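    # MaxDuration accepts either float seconds or a timedelta and compares it against
    # ctx.duration (0.1s for test_context).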
async def test_span_query_evaluator(
    capfire: CaptureLogfire,
):
    """Test the span_query evaluator."""
    # Create a span tree with a known structure
    with context_subtree() as tree:
        with logfire.span('root_span'):
            with logfire.span('child_span', type='important'):
                pass
    # Create a context with this span tree
    context = EvaluatorContext[object, object, object](
        name='span_test',
        inputs=None,
        output=None,
        expected_output=None,
        metadata=None,
        duration=0.1,
        _span_tree=tree,
        attributes={},
        metrics={},
    )
    # Test positive case: query that matches
    query: SpanQuery = {'name_equals': 'child_span', 'has_attributes': {'type': 'important'}}
    evaluator = HasMatchingSpan(query=query)
    result = evaluator.evaluate(context)
    assert result is True
    # Test negative case: query that doesn't match
    query = {'name_equals': 'non_existent_span'}
    evaluator = HasMatchingSpan(query=query)
    result = evaluator.evaluate(context)
    assert result is False
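    # HasMatchingSpan evaluates a SpanQuery (a TypedDict of predicates such as the
    # 'name_equals' and 'has_attributes' keys used above) against the case's span tree
    # and returns whether any span matches.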
async def test_import_errors():
    with pytest.raises(
        ImportError,
        match='The `Python` evaluator has been removed for security reasons. See https://github.com/pydantic/pydantic-ai/pull/2808 for more details and a workaround.',
    ):
        from pydantic_evals.evaluators import Python  # pyright: ignore[reportUnusedImport]
    with pytest.raises(
        ImportError,
        match='The `Python` evaluator has been removed for security reasons. See https://github.com/pydantic/pydantic-ai/pull/2808 for more details and a workaround.',
    ):
        from pydantic_evals.evaluators.common import Python  # pyright: ignore[reportUnusedImport] # noqa: F401
    with pytest.raises(
        ImportError,
        match="cannot import name 'Foo' from 'pydantic_evals.evaluators'",
    ):
        from pydantic_evals.evaluators import Foo  # pyright: ignore[reportUnusedImport]
    with pytest.raises(
        ImportError,
        match="cannot import name 'Foo' from 'pydantic_evals.evaluators.common'",
    ):
        from pydantic_evals.evaluators.common import Foo  # pyright: ignore[reportUnusedImport] # noqa: F401
    with pytest.raises(
        AttributeError,
        match="module 'pydantic_evals.evaluators' has no attribute 'Foo'",
    ):
        import pydantic_evals.evaluators as _evaluators
        _evaluators.Foo
    with pytest.raises(
        AttributeError,
        match="module 'pydantic_evals.evaluators.common' has no attribute 'Foo'",
    ):
        import pydantic_evals.evaluators.common as _common
        _common.Foo