# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Tests for xAI models."""
import asyncio
from unittest.mock import MagicMock
import pytest
from genkit.plugins.xai.models import XAIModel
from genkit.types import (
GenerateRequest,
GenerateResponseChunk,
GenerationCommonConfig,
Message,
Part,
Role,
TextPart,
ToolDefinition,
ToolRequestPart,
)
def _create_sample_request() -> GenerateRequest:
"""Create a sample generation request for testing."""
return GenerateRequest(
messages=[
Message(
role=Role.USER,
content=[TextPart(text='Hello, how are you?')],
)
],
config=GenerationCommonConfig(),
tools=[
ToolDefinition(
name='get_weather',
description='Get weather for a location',
input_schema={
'type': 'object',
'properties': {'location': {'type': 'string', 'description': 'Location name'}},
'required': ['location'],
},
)
],
)
@pytest.mark.asyncio
async def test_generate_basic():
"""Test basic generation."""
sample_request = _create_sample_request()
mock_response = MagicMock()
mock_response.content = "Hello! I'm doing well."
mock_response.finish_reason = 'STOP'
mock_response.usage = MagicMock(
prompt_tokens=10,
completion_tokens=15,
total_tokens=25,
)
mock_response.tool_calls = None
mock_chat = MagicMock()
mock_chat.sample = MagicMock(return_value=mock_response)
mock_client = MagicMock()
mock_client.chat = MagicMock()
mock_client.chat.create = MagicMock(return_value=mock_chat)
model = XAIModel(model_name='grok-3', client=mock_client)
response = await model.generate(sample_request)
assert len(response.message.content) == 1
part = response.message.content[0]
actual_part = part.root if isinstance(part, Part) else part
assert isinstance(actual_part, TextPart)
assert actual_part.text == "Hello! I'm doing well."
assert response.usage.input_tokens == 10
assert response.usage.output_tokens == 15
assert response.finish_reason == 'stop'
@pytest.mark.asyncio
async def test_generate_with_config():
mock_response = MagicMock()
mock_response.content = 'Response'
mock_response.finish_reason = 'STOP'
mock_response.usage = MagicMock(
prompt_tokens=5,
completion_tokens=5,
total_tokens=10,
)
mock_response.tool_calls = None
mock_chat = MagicMock()
mock_chat.sample = MagicMock(return_value=mock_response)
mock_client = MagicMock()
mock_client.chat = MagicMock()
mock_client.chat.create = MagicMock(return_value=mock_chat)
model = XAIModel(model_name='grok-3', client=mock_client)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[TextPart(text='Test')])],
config=GenerationCommonConfig(
temperature=0.7,
max_output_tokens=100,
top_p=0.9,
),
)
await model.generate(request)
call_args = mock_client.chat.create.call_args
assert call_args.kwargs['temperature'] == 0.7
assert call_args.kwargs['max_tokens'] == 100
assert call_args.kwargs['top_p'] == 0.9
def test_to_xai_messages():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
messages = [
Message(role=Role.USER, content=[TextPart(text='Hello')]),
Message(role=Role.MODEL, content=[TextPart(text='Hi there')]),
]
xai_messages = model._to_xai_messages(messages)
assert len(xai_messages) == 2
def test_to_genkit_content():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
mock_response = MagicMock()
mock_response.content = 'Test response'
mock_response.tool_calls = None
content = model._to_genkit_content(mock_response)
assert len(content) == 1
part = content[0]
actual_part = part.root if isinstance(part, Part) else part
assert isinstance(actual_part, TextPart)
assert actual_part.text == 'Test response'
@pytest.mark.asyncio
async def test_streaming_generation():
sample_request = _create_sample_request()
mock_chunk1 = MagicMock()
mock_chunk1.content = 'Hello'
mock_chunk1.choices = []
mock_chunk2 = MagicMock()
mock_chunk2.content = ' world'
mock_chunk2.choices = []
mock_chunk3 = MagicMock()
mock_chunk3.content = '!'
mock_chunk3.choices = []
mock_response = MagicMock()
mock_response.finish_reason = 'STOP'
mock_response.usage = MagicMock(
prompt_tokens=10,
completion_tokens=20,
total_tokens=30,
)
def mock_stream():
yield mock_response, mock_chunk1
yield mock_response, mock_chunk2
yield mock_response, mock_chunk3
mock_chat = MagicMock()
mock_chat.stream = MagicMock(return_value=mock_stream())
mock_client = MagicMock()
mock_client.chat = MagicMock()
mock_client.chat.create = MagicMock(return_value=mock_chat)
model = XAIModel(model_name='grok-3', client=mock_client)
ctx = MagicMock()
ctx.is_streaming = True
collected_chunks = []
def send_chunk(chunk: GenerateResponseChunk):
collected_chunks.append(chunk)
ctx.send_chunk = send_chunk
response = await model.generate(sample_request, ctx)
assert len(collected_chunks) == 3
assert response.usage.input_tokens == 10
assert response.usage.output_tokens == 20
assert response.finish_reason == 'stop'
accumulated_text = ''
for part in response.message.content:
actual_part = part.root if isinstance(part, Part) else part
if isinstance(actual_part, TextPart):
accumulated_text += actual_part.text
assert accumulated_text == 'Hello world!'
@pytest.mark.asyncio
async def test_generate_with_tools():
sample_request = _create_sample_request()
mock_tool_call = MagicMock()
mock_tool_call.id = 'tool_123'
mock_tool_call.function = MagicMock()
mock_tool_call.function.name = 'get_weather'
mock_tool_call.function.arguments = '{"location": "Paris"}'
mock_response = MagicMock()
mock_response.content = None
mock_response.finish_reason = 'TOOL_CALLS'
mock_response.usage = MagicMock(
prompt_tokens=20,
completion_tokens=10,
total_tokens=30,
)
mock_response.tool_calls = [mock_tool_call]
mock_chat = MagicMock()
mock_chat.sample = MagicMock(return_value=mock_response)
mock_client = MagicMock()
mock_client.chat = MagicMock()
mock_client.chat.create = MagicMock(return_value=mock_chat)
model = XAIModel(model_name='grok-3', client=mock_client)
response = await model.generate(sample_request)
assert len(response.message.content) == 1
part = response.message.content[0]
actual_part = part.root if isinstance(part, Part) else part
assert isinstance(actual_part, ToolRequestPart)
assert actual_part.tool_request.name == 'get_weather'
assert actual_part.tool_request.ref == 'tool_123'
@pytest.mark.asyncio
async def test_build_params_basic():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[TextPart(text='Test')])],
config=GenerationCommonConfig(),
)
params = model._build_params(request)
assert params['model'] == 'grok-3'
assert 'messages' in params
assert params['max_tokens'] == 4096
@pytest.mark.asyncio
async def test_build_params_with_config():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[TextPart(text='Test')])],
config={
'temperature': 0.5,
'max_output_tokens': 200,
'top_p': 0.8,
},
)
params = model._build_params(request)
assert params['temperature'] == 0.5
assert params['max_tokens'] == 200
assert params['top_p'] == 0.8
@pytest.mark.asyncio
async def test_build_params_with_xai_specific_config():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
request = GenerateRequest(
messages=[Message(role=Role.USER, content=[TextPart(text='Test')])],
config={
'temperature': 0.7,
'max_output_tokens': 300,
'deferred': True,
'reasoning_effort': 'high',
'web_search_options': {'enabled': True},
},
)
params = model._build_params(request)
assert params['temperature'] == 0.7
assert params['max_tokens'] == 300
assert params['deferred'] is True
assert params['reasoning_effort'] == 'high'
assert params['web_search_options'] == {'enabled': True}
@pytest.mark.asyncio
async def test_to_genkit_content_parses_json_arguments():
mock_client = MagicMock()
model = XAIModel(model_name='grok-3', client=mock_client)
mock_tool_call = MagicMock()
mock_tool_call.id = 'call_123'
mock_tool_call.function = MagicMock()
mock_tool_call.function.name = 'get_weather'
mock_tool_call.function.arguments = '{"location": "Paris", "unit": "celsius"}'
mock_response = MagicMock()
mock_response.content = 'Some response'
mock_response.tool_calls = [mock_tool_call]
content = model._to_genkit_content(mock_response)
assert len(content) == 2
assert isinstance(content[0], TextPart)
assert content[0].text == 'Some response'
assert isinstance(content[1], ToolRequestPart)
assert content[1].tool_request.name == 'get_weather'
assert isinstance(content[1].tool_request.input, dict)
assert content[1].tool_request.input['location'] == 'Paris'
assert content[1].tool_request.input['unit'] == 'celsius'