# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Tests for Hugging Face model implementation."""
from unittest.mock import MagicMock, patch
import pytest
from genkit.core.typing import (
GenerateRequest,
Message,
Part,
Role,
TextPart,
ToolDefinition,
ToolRequest,
ToolRequestPart,
ToolResponse,
ToolResponsePart,
)
from genkit.plugins.huggingface.models import HuggingFaceModel


@pytest.fixture
def mock_client() -> MagicMock:
"""Create a mock InferenceClient."""
return MagicMock()


@pytest.fixture
def model(mock_client: MagicMock) -> HuggingFaceModel:
"""Create a HuggingFaceModel with mocked client."""
with patch('genkit.plugins.huggingface.models.InferenceClient', return_value=mock_client):
return HuggingFaceModel(
model='meta-llama/Llama-3.3-70B-Instruct',
token='test-token',
)


def test_model_initialization(model: HuggingFaceModel) -> None:
"""Test model initialization."""
assert model.name == 'meta-llama/Llama-3.3-70B-Instruct'


def test_model_with_provider() -> None:
"""Test model initialization with provider."""
with patch('genkit.plugins.huggingface.models.InferenceClient'):
model = HuggingFaceModel(
model='meta-llama/Llama-3.3-70B-Instruct',
token='test-token',
provider='groq',
)
assert model.provider == 'groq'


def test_get_model_info(model: HuggingFaceModel) -> None:
"""Test get_model_info returns expected structure."""
info = model.get_model_info()
assert info is not None
assert 'name' in info
assert 'supports' in info


def test_convert_messages_text_only(model: HuggingFaceModel) -> None:
"""Test converting simple text messages."""
messages = [
Message(role=Role.USER, content=[Part(root=TextPart(text='Hello'))]),
Message(role=Role.MODEL, content=[Part(root=TextPart(text='Hi there!'))]),
]
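    # Genkit roles are expected to map to OpenAI-style chat roles: USER -> 'user', MODEL -> 'assistant'.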
hf_messages = model._convert_messages(messages)
assert len(hf_messages) == 2
assert hf_messages[0]['role'] == 'user'
assert hf_messages[0]['content'] == 'Hello'
assert hf_messages[1]['role'] == 'assistant'
assert hf_messages[1]['content'] == 'Hi there!'


def test_convert_messages_system_role(model: HuggingFaceModel) -> None:
"""Test converting system messages."""
messages = [
Message(role=Role.SYSTEM, content=[Part(root=TextPart(text='You are helpful.'))]),
Message(role=Role.USER, content=[Part(root=TextPart(text='Hello'))]),
]
hf_messages = model._convert_messages(messages)
assert len(hf_messages) == 2
assert hf_messages[0]['role'] == 'system'
assert hf_messages[0]['content'] == 'You are helpful.'


def test_convert_messages_with_tool_request(model: HuggingFaceModel) -> None:
"""Test converting messages with tool requests."""
messages = [
Message(
role=Role.MODEL,
content=[
Part(
root=ToolRequestPart(
tool_request=ToolRequest(
ref='call_123',
name='get_weather',
input={'city': 'Paris'},
)
)
)
],
),
]
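    # The ToolRequestPart should surface as an assistant message carrying an OpenAI-style 'tool_calls' list.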
hf_messages = model._convert_messages(messages)
assert len(hf_messages) == 1
assert hf_messages[0]['role'] == 'assistant'
assert 'tool_calls' in hf_messages[0]
assert len(hf_messages[0]['tool_calls']) == 1
assert hf_messages[0]['tool_calls'][0]['id'] == 'call_123'
assert hf_messages[0]['tool_calls'][0]['function']['name'] == 'get_weather'


def test_convert_messages_with_tool_response(model: HuggingFaceModel) -> None:
"""Test converting messages with tool responses."""
messages = [
Message(
role=Role.TOOL,
content=[
Part(
root=ToolResponsePart(
tool_response=ToolResponse(
ref='call_123',
name='get_weather',
output={'temperature': 20},
)
)
)
],
),
]
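    # The ToolResponsePart should surface as a 'tool' message whose tool_call_id echoes the originating request ref.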
hf_messages = model._convert_messages(messages)
assert len(hf_messages) == 1
assert hf_messages[0]['role'] == 'tool'
assert hf_messages[0]['tool_call_id'] == 'call_123'


def test_convert_tools(model: HuggingFaceModel) -> None:
"""Test converting tool definitions."""
tools = [
ToolDefinition(
name='get_weather',
description='Get weather for a city',
input_schema={
'type': 'object',
'properties': {'city': {'type': 'string'}},
'required': ['city'],
},
),
]
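    # Tool definitions should convert to OpenAI-style function specs; per the assertion below, the
    # converter is also expected to add 'additionalProperties' to the parameter schema.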
hf_tools = model._convert_tools(tools)
assert len(hf_tools) == 1
assert hf_tools[0]['type'] == 'function'
assert hf_tools[0]['function']['name'] == 'get_weather'
assert hf_tools[0]['function']['description'] == 'Get weather for a city'
assert 'additionalProperties' in hf_tools[0]['function']['parameters']


def test_get_response_format_json(model: HuggingFaceModel) -> None:
    """Test response format for JSON output."""
    output = OutputConfig(format='json')
    result = model._get_response_format(output)
    assert result == {'type': 'json'}


def test_get_response_format_json_with_schema(model: HuggingFaceModel) -> None:
    """Test response format for JSON with schema."""
    schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}}
    output = OutputConfig(format='json', schema=schema)
    result = model._get_response_format(output)
    assert result == {'type': 'json', 'value': schema}


def test_get_response_format_text(model: HuggingFaceModel) -> None:
    """Test response format for text output returns None."""
    output = OutputConfig(format='text')
    result = model._get_response_format(output)
    assert result is None


@patch('genkit.plugins.huggingface.models.InferenceClient')
@pytest.mark.asyncio
async def test_generate_simple_request(mock_client_class: MagicMock) -> None:
"""Test simple generate request."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
    # Mock an OpenAI-style chat completion response: a single text choice plus usage metadata.
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = 'Hello, world!'
mock_response.choices[0].message.tool_calls = None
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat_completion.return_value = mock_response
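    # The patched constructor hands mock_client to the model built below, so generate() sees the canned response.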
model = HuggingFaceModel(
model='meta-llama/Llama-3.3-70B-Instruct',
token='test-token',
)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Hi'))])],
)
response = await model.generate(request)
assert response.message is not None
assert response.message.role == Role.MODEL
assert len(response.message.content) == 1
assert response.message.content[0].root.text == 'Hello, world!'
assert response.usage is not None
assert response.usage.input_tokens == 10
assert response.usage.output_tokens == 5


@patch('genkit.plugins.huggingface.models.InferenceClient')
@pytest.mark.asyncio
async def test_generate_with_tool_calls(mock_client_class: MagicMock) -> None:
"""Test generate with tool calls in response."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock tool call in response
mock_tool_call = MagicMock()
mock_tool_call.id = 'call_abc'
mock_tool_call.function = MagicMock()
mock_tool_call.function.name = 'get_weather'
mock_tool_call.function.arguments = '{"city": "Paris"}'
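    # Arguments arrive as a JSON string, mirroring the OpenAI-compatible tool-call payload.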
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
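    # The reply carries only a tool call: no text content and no usage metadata.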
mock_response.choices[0].message.content = None
mock_response.choices[0].message.tool_calls = [mock_tool_call]
mock_response.usage = None
mock_client.chat_completion.return_value = mock_response
model = HuggingFaceModel(
model='meta-llama/Llama-3.3-70B-Instruct',
token='test-token',
)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='Weather in Paris?'))])],
tools=[
ToolDefinition(
name='get_weather',
description='Get weather',
input_schema={'type': 'object', 'properties': {'city': {'type': 'string'}}},
)
],
)
response = await model.generate(request)
assert response.message is not None
assert len(response.message.content) == 1
tool_part = response.message.content[0].root
assert isinstance(tool_part, ToolRequestPart)
assert tool_part.tool_request.name == 'get_weather'
assert tool_part.tool_request.ref == 'call_abc'


def test_to_generate_fn(model: HuggingFaceModel) -> None:
    """Test that to_generate_fn returns the model's generate callable."""
fn = model.to_generate_fn()
assert callable(fn)
assert fn == model.generate