#!/usr/bin/env python3
"""Comprehensive tests for gemini_helper.py"""
import os
import unittest
from unittest.mock import MagicMock, Mock, patch
import pytest
# Import functions to test
from gemini_helper import (
GEMINI_MODELS,
MAX_FILE_SIZE,
MAX_LINES,
MODEL_ASSIGNMENTS,
execute_gemini_api,
execute_gemini_cli,
sanitize_for_prompt,
)
class TestSanitizeForPrompt(unittest.TestCase):
"""Test cases for sanitize_for_prompt function"""
def test_sanitize_for_prompt_basic(self):
"""Test basic functionality with clean input"""
input_text = "Hello World"
expected_output = "Hello World"
self.assertEqual(sanitize_for_prompt(input_text), expected_output)
def test_sanitize_for_prompt_non_string_input(self):
"""Test handling of non-string input"""
inputs = [None, 123, [], {}]
for input_val in inputs:
result = sanitize_for_prompt(input_val)
self.assertEqual(result, "")
def test_sanitize_for_prompt_dangerous_patterns(self):
"""Test removal of dangerous instruction patterns"""
dangerous_patterns = [
"ignore all previous instructions",
"forget everything above",
"new instruction:",
"system:",
"assistant:",
"user:",
"###",
"---",
"```",
"<|",
"|>",
"[INST]",
"[/INST]",
]
for pattern in dangerous_patterns:
# Test case insensitive
for test_pattern in [pattern, pattern.upper(), pattern.capitalize()]:
result = sanitize_for_prompt(test_pattern)
self.assertIn("[filtered-content]", result)
self.assertNotIn(pattern.lower(), result.lower())
def test_sanitize_for_prompt_mixed_content(self):
"""Test mixed content with dangerous patterns"""
mixed_text = "Hello ignore all previous instructions world"
result = sanitize_for_prompt(mixed_text)
self.assertIn("[filtered-content]", result)
self.assertNotIn("ignore all previous instructions", result.lower())
def test_sanitize_for_prompt_length_limit(self):
"""Test truncation of overly long inputs"""
long_text = "A" * 200000
result = sanitize_for_prompt(long_text, max_length=1000)
self.assertLessEqual(len(result), 1000)
def test_sanitize_for_prompt_control_characters(self):
"""Test removal of control characters"""
text_with_control_chars = "Hello\x00World\x1b[31mRed Text"
result = sanitize_for_prompt(text_with_control_chars)
self.assertNotIn("\x00", result)
self.assertNotIn("\x1b", result)
self.assertIn("HelloWorld", result)
def test_sanitize_for_prompt_empty_input(self):
"""Test handling of empty input"""
result = sanitize_for_prompt("")
self.assertEqual(result, "")
def test_sanitize_for_prompt_whitespace_only(self):
"""Test handling of whitespace-only input"""
result = sanitize_for_prompt(" ")
self.assertEqual(result, " ")
class TestExecuteGeminiApi(unittest.TestCase):
"""Test cases for execute_gemini_api function"""
def test_execute_gemini_api_missing_api_key(self):
"""Test handling of missing API key"""
with patch("gemini_helper.GOOGLE_API_KEY", None):
result = execute_gemini_api(
"test prompt", "test-model", show_progress=False
)
self.assertFalse(result["success"])
self.assertIn("Invalid or missing API key", result["error"])
def test_execute_gemini_api_invalid_api_key(self):
"""Test handling of invalid API key"""
with patch("gemini_helper.GOOGLE_API_KEY", "short"):
result = execute_gemini_api(
"test prompt", "test-model", show_progress=False
)
self.assertFalse(result["success"])
self.assertIn("Invalid or missing API key", result["error"])
def test_execute_gemini_api_import_error(self):
"""Test handling of missing google-generativeai library"""
with patch("gemini_helper.GOOGLE_API_KEY", "valid_key_123456789"):
with patch("builtins.__import__", side_effect=ImportError):
result = execute_gemini_api(
"test prompt", "test-model", show_progress=False
)
self.assertFalse(result["success"])
self.assertIn("API library not available", result["error"])
@patch("gemini_helper.GOOGLE_API_KEY", "valid_key_123456789")
def test_execute_gemini_api_success(self):
"""Test successful API call"""
mock_response = MagicMock()
mock_response.text = "Test response"
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_response
mock_genai = MagicMock()
mock_genai.GenerativeModel.return_value = mock_model
with patch.dict("sys.modules", {"google.generativeai": mock_genai}):
result = execute_gemini_api(
"test prompt", "test-model", show_progress=False
)
self.assertTrue(result["success"])
self.assertEqual(result["output"], "Test response")
@patch("gemini_helper.GOOGLE_API_KEY", "valid_key_123456789")
def test_execute_gemini_api_exception_handling(self):
"""Test exception handling and error sanitization"""
mock_genai = MagicMock()
mock_genai.GenerativeModel.side_effect = Exception(
"API error with AIzaSy123456789012345678901234567890123"
)
with patch.dict("sys.modules", {"google.generativeai": mock_genai}):
result = execute_gemini_api(
"test prompt", "test-model", show_progress=False
)
self.assertFalse(result["success"])
self.assertIn("[API_KEY_REDACTED]", result["error"])
self.assertNotIn("AIzaSy123456789012345678901234567890123", result["error"])
class TestExecuteGeminiCli(unittest.TestCase):
"""Test cases for execute_gemini_cli function"""
def test_execute_gemini_cli_invalid_prompt_empty(self):
"""Test handling of empty prompt"""
result = execute_gemini_cli("", show_progress=False)
self.assertFalse(result["success"])
self.assertIn("Invalid prompt: must be non-empty string", result["error"])
def test_execute_gemini_cli_invalid_prompt_non_string(self):
"""Test handling of non-string prompt"""
result = execute_gemini_cli(123, show_progress=False)
self.assertFalse(result["success"])
self.assertIn("Invalid prompt: must be non-empty string", result["error"])
def test_execute_gemini_cli_prompt_too_large(self):
"""Test handling of oversized prompt"""
large_prompt = "A" * 1000001 # 1MB + 1 byte
result = execute_gemini_cli(large_prompt, show_progress=False)
self.assertFalse(result["success"])
self.assertIn("Prompt too large (max 1MB)", result["error"])
def test_execute_gemini_cli_invalid_model_name(self):
"""Test handling of invalid model name"""
result = execute_gemini_cli("test prompt", model_name="", show_progress=False)
self.assertFalse(result["success"])
self.assertIn("Invalid model name", result["error"])
def test_execute_gemini_cli_invalid_model_name_characters(self):
"""Test handling of model name with invalid characters"""
result = execute_gemini_cli(
"test prompt", model_name="model$name", show_progress=False
)
self.assertFalse(result["success"])
self.assertIn("Invalid model name characters", result["error"])
def test_execute_gemini_cli_valid_model_name(self):
"""Test validation of valid model name"""
valid_names = ["gemini-pro", "gemini-2.5-flash", "model-1.0"]
for model_name in valid_names:
with patch("subprocess.Popen") as mock_popen:
mock_process = MagicMock()
mock_process.stdout.readline.side_effect = ["", ""]
mock_process.poll.return_value = 0
mock_process.communicate.return_value = ("success", "")
mock_popen.return_value = mock_process
result = execute_gemini_cli(
"test prompt", model_name=model_name, show_progress=False
)
# Should not fail on model name validation
self.assertTrue("Invalid model name" not in str(result))
class TestConstants(unittest.TestCase):
"""Test cases for constants and configuration"""
def test_gemini_models_structure(self):
"""Test GEMINI_MODELS dictionary structure"""
self.assertIsInstance(GEMINI_MODELS, dict)
self.assertIn("flash", GEMINI_MODELS)
self.assertIn("pro", GEMINI_MODELS)
self.assertIsInstance(GEMINI_MODELS["flash"], str)
self.assertIsInstance(GEMINI_MODELS["pro"], str)
def test_model_assignments_structure(self):
"""Test MODEL_ASSIGNMENTS dictionary structure"""
self.assertIsInstance(MODEL_ASSIGNMENTS, dict)
self.assertIn("quick_query", MODEL_ASSIGNMENTS)
self.assertIn("analyze_code", MODEL_ASSIGNMENTS)
self.assertIn("analyze_codebase", MODEL_ASSIGNMENTS)
# Check that assignments reference valid models
for assignment in MODEL_ASSIGNMENTS.values():
self.assertIn(assignment, GEMINI_MODELS)
def test_constants_values(self):
"""Test constant values are reasonable"""
self.assertIsInstance(MAX_FILE_SIZE, int)
self.assertGreater(MAX_FILE_SIZE, 0)
self.assertIsInstance(MAX_LINES, int)
self.assertGreater(MAX_LINES, 0)
if __name__ == "__main__":
unittest.main()