# test_langchain.py (8.72 kB)
from dataclasses import dataclass
from typing import Any
import pytest
from inline_snapshot import snapshot
from pydantic.json_schema import JsonSchemaValue
from pydantic_ai import Agent
from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain
@dataclass
class SimulatedLangChainTool:
name: str
description: str
args: dict[str, dict[str, str]]
additional_properties_missing: bool = False
def run(
self,
tool_input: str | dict[str, Any],
verbose: bool | None = None,
start_color: str | None = 'green',
color: str | None = 'green',
callbacks: Any = None,
*,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
run_name: str | None = None,
run_id: Any | None = None,
config: Any | None = None,
tool_call_id: str | None = None,
**kwargs: Any,
) -> Any:
if isinstance(tool_input, dict):
tool_input = dict(sorted(tool_input.items()))
return f'I was called with {tool_input}'
def get_input_jsonschema(self) -> JsonSchemaValue:
if self.additional_properties_missing:
return {
'type': 'object',
'properties': self.args,
}
return {
'type': 'object',
'properties': self.args,
'additionalProperties': False,
}
# Shared fixture: a simulated "file_search" tool where `dir_path` has a default ('.')
# and `pattern` has none.
langchain_tool = SimulatedLangChainTool(
    name='file_search',
    description='Recursively search for files in a subdirectory that match the regex pattern',
    args=dict(
        dir_path=dict(
            default='.',
            description='Subdirectory to search in.',
            title='Dir Path',
            type='string',
        ),
        pattern=dict(
            description='Unix shell regex, where * matches everything.',
            title='Pattern',
            type='string',
        ),
    ),
)
def test_langchain_tool_conversion():
    """An agent on the 'test' model can call a tool converted with ``tool_from_langchain``.

    The snapshot shows the tool being called with 'a' for ``pattern`` and the '.' default
    for ``dir_path``.
    """
    agent = Agent('test', tools=[tool_from_langchain(langchain_tool)], retries=7)
    run_result = agent.run_sync('foobar')
    assert run_result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}")
def test_langchain_toolset():
    """Exposing the fixture via ``LangChainToolset`` lets the agent call it.

    The expected snapshot matches the one for ``tool_from_langchain`` conversion.
    """
    agent = Agent('test', toolsets=[LangChainToolset([langchain_tool])], retries=7)
    run_result = agent.run_sync('foobar')
    assert run_result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}")
def test_langchain_tool_no_additional_properties():
    """Conversion still works when the input schema has no ``additionalProperties`` key."""
    tool_args = {
        'dir_path': {
            'default': '.',
            'description': 'Subdirectory to search in.',
            'title': 'Dir Path',
            'type': 'string',
        },
        'pattern': {
            'description': 'Unix shell regex, where * matches everything.',
            'title': 'Pattern',
            'type': 'string',
        },
    }
    # Local name differs from the module-level `langchain_tool` fixture to avoid shadowing it.
    tool = SimulatedLangChainTool(
        name='file_search',
        description='Recursively search for files in a subdirectory that match the regex pattern',
        args=tool_args,
        additional_properties_missing=True,
    )
    agent = Agent('test', tools=[tool_from_langchain(tool)], retries=7)
    run_result = agent.run_sync('foobar')
    assert run_result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}")
def test_langchain_tool_conversion_no_defaults():
    """With no defaults in the schema, every argument comes from the model.

    The snapshot shows 'a' supplied for both ``dir_path`` and ``pattern``.
    """
    tool_args = {
        'dir_path': {
            'description': 'Subdirectory to search in.',
            'title': 'Dir Path',
            'type': 'string',
        },
        'pattern': {
            'description': 'Unix shell regex, where * matches everything.',
            'title': 'Pattern',
            'type': 'string',
        },
    }
    tool = SimulatedLangChainTool(
        name='file_search',
        description='Recursively search for files in a subdirectory that match the regex pattern',
        args=tool_args,
    )
    agent = Agent('test', tools=[tool_from_langchain(tool)], retries=7)
    run_result = agent.run_sync('foobar')
    assert run_result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': 'a', 'pattern': 'a'}\"}")
def test_langchain_tool_conversion_no_required():
    """When every argument has a default, the tool is called with the defaults.

    The snapshot shows '.' for ``dir_path`` and '*' for ``pattern``.
    """
    tool_args = {
        'dir_path': {
            'default': '.',
            'description': 'Subdirectory to search in.',
            'title': 'Dir Path',
            'type': 'string',
        },
        'pattern': {
            'default': '*',
            'description': 'Unix shell regex, where * matches everything.',
            'title': 'Pattern',
            'type': 'string',
        },
    }
    tool = SimulatedLangChainTool(
        name='file_search',
        description='Recursively search for files in a subdirectory that match the regex pattern',
        args=tool_args,
    )
    agent = Agent('test', tools=[tool_from_langchain(tool)], retries=7)
    run_result = agent.run_sync('foobar')
    assert run_result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': '*'}\"}")
def test_langchain_tool_defaults():
    """Arguments the caller omits are filled from the schema's ``default`` values.

    Reuses the module-level ``langchain_tool`` fixture instead of redefining an identical
    copy. In that fixture ``dir_path`` defaults to '.' and ``pattern`` is required.
    """
    pydantic_tool = tool_from_langchain(langchain_tool)
    result = pydantic_tool.function(pattern='something')  # type: ignore
    assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}")
def test_langchain_tool_positional():
    """Calling the converted function with a positional argument raises ``AssertionError``.

    The expected message is 'This should always be called with kwargs'.
    """
    tool_args = {
        'pattern': {
            'description': 'Unix shell regex, where * matches everything.',
            'title': 'Pattern',
            'type': 'string',
        },
        'dir_path': {
            'default': '.',
            'description': 'Subdirectory to search in.',
            'title': 'Dir Path',
            'type': 'string',
        },
    }
    tool = SimulatedLangChainTool(
        name='file_search',
        description='Recursively search for files in a subdirectory that match the regex pattern',
        args=tool_args,
    )
    converted = tool_from_langchain(tool)
    with pytest.raises(AssertionError, match='This should always be called with kwargs'):
        converted.function('something')  # type: ignore
def test_langchain_tool_default_override():
    """An explicitly passed argument overrides the schema ``default`` for that argument.

    Reuses the module-level ``langchain_tool`` fixture instead of redefining an identical
    copy. In that fixture ``dir_path`` defaults to '.'.
    """
    pydantic_tool = tool_from_langchain(langchain_tool)
    result = pydantic_tool.function(pattern='something', dir_path='somewhere')  # type: ignore
    assert result == snapshot("I was called with {'dir_path': 'somewhere', 'pattern': 'something'}")
def test_simulated_tool_string_input():
    """``SimulatedLangChainTool.run`` echoes a non-dict ``tool_input`` unchanged.

    The tool's schema is irrelevant to this code path, so the module-level
    ``langchain_tool`` fixture is reused rather than building an identical copy.
    """
    result = langchain_tool.run('this string argument')
    assert result == snapshot('I was called with this string argument')