# test_gemini_token_usage.py • 3.93 kB
"""Tests for Gemini provider token usage extraction."""
import unittest
from unittest.mock import Mock
from providers.gemini import GeminiModelProvider
class TestGeminiTokenUsage(unittest.TestCase):
    """Test Gemini provider token usage handling.

    Each test hands ``GeminiModelProvider._extract_usage`` a mocked response
    and checks which of the ``input_tokens``, ``output_tokens`` and
    ``total_tokens`` keys appear in the returned mapping.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.provider = GeminiModelProvider("test-key")

    @staticmethod
    def _make_response(prompt_tokens, candidates_tokens):
        """Build a mocked response carrying the given usage token counts.

        Args:
            prompt_tokens: Value assigned to
                ``usage_metadata.prompt_token_count`` (may be ``None``).
            candidates_tokens: Value assigned to
                ``usage_metadata.candidates_token_count`` (may be ``None``).

        Returns:
            A ``Mock`` whose ``usage_metadata`` attribute is a ``Mock`` with
            both token-count attributes set to the supplied values.
        """
        response = Mock()
        response.usage_metadata = Mock()
        response.usage_metadata.prompt_token_count = prompt_tokens
        response.usage_metadata.candidates_token_count = candidates_tokens
        return response

    def test_extract_usage_with_valid_tokens(self):
        """Test token extraction with valid token counts."""
        response = self._make_response(prompt_tokens=100, candidates_tokens=50)

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        self.assertEqual(usage["output_tokens"], 50)
        self.assertEqual(usage["total_tokens"], 150)

    def test_extract_usage_with_none_input_tokens(self):
        """Test token extraction when input_tokens is None (regression test for bug)."""
        # A None prompt count was causing crashes.
        response = self._make_response(prompt_tokens=None, candidates_tokens=50)

        usage = self.provider._extract_usage(response)

        # Should not include input_tokens when None
        self.assertNotIn("input_tokens", usage)
        self.assertEqual(usage["output_tokens"], 50)
        # Should not calculate total_tokens when input is None
        self.assertNotIn("total_tokens", usage)

    def test_extract_usage_with_none_output_tokens(self):
        """Test token extraction when output_tokens is None (regression test for bug)."""
        # A None candidates count was causing crashes.
        response = self._make_response(prompt_tokens=100, candidates_tokens=None)

        usage = self.provider._extract_usage(response)

        self.assertEqual(usage["input_tokens"], 100)
        # Should not include output_tokens when None
        self.assertNotIn("output_tokens", usage)
        # Should not calculate total_tokens when output is None
        self.assertNotIn("total_tokens", usage)

    def test_extract_usage_with_both_none_tokens(self):
        """Test token extraction when both token counts are None."""
        response = self._make_response(prompt_tokens=None, candidates_tokens=None)

        usage = self.provider._extract_usage(response)

        # Should return empty dict when all tokens are None
        self.assertEqual(usage, {})

    def test_extract_usage_without_usage_metadata(self):
        """Test token extraction when response has no usage_metadata."""
        # An empty spec means the mock exposes no attributes at all, so
        # accessing ``usage_metadata`` raises AttributeError.
        response = Mock(spec=[])

        usage = self.provider._extract_usage(response)

        # Should return empty dict
        self.assertEqual(usage, {})

    def test_extract_usage_with_zero_tokens(self):
        """Test token extraction with zero token counts."""
        response = self._make_response(prompt_tokens=0, candidates_tokens=0)

        usage = self.provider._extract_usage(response)

        # Zero is a real count and must be reported, unlike None.
        self.assertEqual(usage["input_tokens"], 0)
        self.assertEqual(usage["output_tokens"], 0)
        self.assertEqual(usage["total_tokens"], 0)

    def test_extract_usage_missing_attributes(self):
        """Test token extraction when metadata lacks token count attributes."""
        response = Mock()
        # The metadata object exists but, with an empty spec, has neither
        # token-count attribute.
        response.usage_metadata = Mock(spec=[])

        usage = self.provider._extract_usage(response)

        # Should return empty dict when attributes are missing
        self.assertEqual(usage, {})
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()