"""
Unit tests for visual_tools module.
"""
import asyncio
import base64
import io
import pytest
from unittest.mock import Mock, patch, MagicMock
import numpy as np
from PIL import Image
import cv2
from src.percepta_mcp.tools.visual_tools import VisualAnalysis
from src.percepta_mcp.config import Settings
class TestVisualAnalysis:
    """Test cases for VisualAnalysis class.

    Every test awaits a ``VisualAnalysis`` coroutine method and inspects the
    dictionary it returns. Success paths assert ``result["success"] is True``
    plus method-specific keys; failure paths assert ``result["success"] is
    False`` together with an ``"error"`` entry. Async tests carry the
    ``pytest.mark.asyncio`` marker.
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings restricted to the ``Settings`` interface."""
        return Mock(spec=Settings)

    @pytest.fixture
    def visual_analysis(self, mock_settings):
        """Create VisualAnalysis instance built from the mocked settings."""
        return VisualAnalysis(mock_settings)

    @pytest.fixture
    def sample_image_base64(self):
        """Create a sample base64 encoded image.

        Returns:
            The base64 text (UTF-8 decoded ``str``) of a 100x100 solid red
            RGB image serialized as PNG.
        """
        # Create a simple 100x100 RGB image
        image = Image.new('RGB', (100, 100), color='red')
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image_bytes = buffer.getvalue()
        return base64.b64encode(image_bytes).decode('utf-8')

    @pytest.fixture
    def sample_image_data_url(self, sample_image_base64):
        """Create a data URL format image wrapping the base64 sample."""
        return f"data:image/png;base64,{sample_image_base64}"

    @pytest.fixture
    def sample_cv_image(self):
        """Create a sample OpenCV image: a 100x100x3 all-zero uint8 array."""
        return np.zeros((100, 100, 3), dtype=np.uint8)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_init(self, mock_settings):
        """Test that the constructor stores the settings it was given."""
        visual_analysis = VisualAnalysis(mock_settings)
        assert visual_analysis.settings == mock_settings

    # ------------------------------------------------------------------
    # analyze_image
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_analyze_image_general_success(self, visual_analysis, sample_image_base64):
        """Test successful general image analysis."""
        result = await visual_analysis.analyze_image(sample_image_base64, "general")
        assert result["success"] is True
        assert "analysis" in result
        analysis = result["analysis"]
        # The expected values mirror the sample fixture: 100x100 RGB PNG.
        assert analysis["width"] == 100
        assert analysis["height"] == 100
        assert analysis["format"] == "PNG"
        assert analysis["mode"] == "RGB"
        assert analysis["analysis_type"] == "general"
        assert "description" in analysis

    @pytest.mark.asyncio
    async def test_analyze_image_text_analysis(self, visual_analysis, sample_image_base64):
        """Test text analysis type."""
        result = await visual_analysis.analyze_image(sample_image_base64, "text")
        assert result["success"] is True
        analysis = result["analysis"]
        assert analysis["analysis_type"] == "text"
        assert "extracted_text" in analysis

    @pytest.mark.asyncio
    async def test_analyze_image_objects_analysis(self, visual_analysis, sample_image_base64):
        """Test objects analysis type."""
        result = await visual_analysis.analyze_image(sample_image_base64, "objects")
        assert result["success"] is True
        analysis = result["analysis"]
        assert analysis["analysis_type"] == "objects"
        assert "detected_objects" in analysis

    @pytest.mark.asyncio
    async def test_analyze_image_with_prompt(self, visual_analysis, sample_image_base64):
        """Test image analysis with custom prompt.

        The prompt text is expected to appear inside ``prompt_response``.
        """
        prompt = "What do you see in this image?"
        result = await visual_analysis.analyze_image(sample_image_base64, "general", prompt)
        assert result["success"] is True
        analysis = result["analysis"]
        assert "prompt_response" in analysis
        assert prompt in analysis["prompt_response"]

    @pytest.mark.asyncio
    async def test_analyze_image_data_url_format(self, visual_analysis, sample_image_data_url):
        """Test image analysis with data URL format."""
        result = await visual_analysis.analyze_image(sample_image_data_url, "general")
        assert result["success"] is True
        assert "analysis" in result
        assert result["analysis"]["width"] == 100
        assert result["analysis"]["height"] == 100

    @pytest.mark.asyncio
    async def test_analyze_image_invalid_base64(self, visual_analysis):
        """Test image analysis with invalid base64 data."""
        invalid_data = "invalid_base64_data"
        result = await visual_analysis.analyze_image(invalid_data, "general")
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    async def test_analyze_image_corrupted_image(self, visual_analysis):
        """Test image analysis with corrupted image data."""
        # Valid base64 but not a valid image
        corrupted_data = base64.b64encode(b"not an image").decode('utf-8')
        result = await visual_analysis.analyze_image(corrupted_data, "general")
        assert result["success"] is False
        assert "error" in result

    # ------------------------------------------------------------------
    # extract_text
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_extract_text_success(self, visual_analysis, sample_image_base64):
        """Test successful text extraction."""
        result = await visual_analysis.extract_text(sample_image_base64, "eng")
        assert result["success"] is True
        assert "text" in result
        assert result["language"] == "eng"
        assert "image_info" in result
        assert result["image_info"]["width"] == 100
        assert result["image_info"]["height"] == 100

    @pytest.mark.asyncio
    async def test_extract_text_different_language(self, visual_analysis, sample_image_base64):
        """Test text extraction with different language."""
        result = await visual_analysis.extract_text(sample_image_base64, "chi_sim")
        assert result["success"] is True
        assert result["language"] == "chi_sim"

    @pytest.mark.asyncio
    async def test_extract_text_data_url(self, visual_analysis, sample_image_data_url):
        """Test text extraction with data URL format."""
        result = await visual_analysis.extract_text(sample_image_data_url, "eng")
        assert result["success"] is True
        assert "text" in result

    @pytest.mark.asyncio
    async def test_extract_text_invalid_image(self, visual_analysis):
        """Test text extraction with invalid image data."""
        invalid_data = "invalid_base64_data"
        result = await visual_analysis.extract_text(invalid_data, "eng")
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    @patch('cv2.imdecode')
    async def test_extract_text_decode_failure(self, mock_imdecode, visual_analysis, sample_image_base64):
        """Test text extraction when cv2.imdecode returns None."""
        mock_imdecode.return_value = None
        result = await visual_analysis.extract_text(sample_image_base64, "eng")
        assert result["success"] is False
        assert "Invalid image data" in result["error"]

    # ------------------------------------------------------------------
    # compare_images
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_compare_images_structural_method(self, visual_analysis, sample_image_base64):
        """Test image comparison with structural method."""
        # Use same image for comparison (should have high similarity)
        result = await visual_analysis.compare_images(
            sample_image_base64, sample_image_base64, "structural"
        )
        assert result["success"] is True
        assert "similarity" in result
        assert result["method"] == "structural"
        assert "image1_size" in result
        assert "image2_size" in result
        # Same image should have high similarity
        assert result["similarity"] > 0.9

    @pytest.mark.asyncio
    async def test_compare_images_histogram_method(self, visual_analysis, sample_image_base64):
        """Test image comparison with histogram method."""
        result = await visual_analysis.compare_images(
            sample_image_base64, sample_image_base64, "histogram"
        )
        assert result["success"] is True
        assert result["method"] == "histogram"
        # Same image should have perfect correlation
        assert result["similarity"] > 0.99

    @pytest.mark.asyncio
    async def test_compare_images_default_method(self, visual_analysis, sample_image_base64):
        """Test image comparison when the method is given as ``"pixel"``.

        NOTE(review): the test name says "default", but ``"pixel"`` is passed
        explicitly; presumably the implementation routes it through its
        fallback branch -- confirm against ``compare_images``.
        """
        result = await visual_analysis.compare_images(
            sample_image_base64, sample_image_base64, "pixel"
        )
        assert result["success"] is True
        assert result["method"] == "pixel"
        # Same image should have high similarity
        assert result["similarity"] > 0.9

    @pytest.mark.asyncio
    async def test_compare_images_data_url_format(self, visual_analysis, sample_image_data_url):
        """Test image comparison with data URL format."""
        result = await visual_analysis.compare_images(
            sample_image_data_url, sample_image_data_url, "structural"
        )
        assert result["success"] is True
        assert result["similarity"] > 0.9

    @pytest.mark.asyncio
    async def test_compare_images_different_sizes(self, visual_analysis):
        """Test image comparison with different sized images."""
        # Create two different sized images
        image1 = Image.new('RGB', (100, 100), color='red')
        image2 = Image.new('RGB', (200, 150), color='blue')
        buffer1 = io.BytesIO()
        image1.save(buffer1, format='PNG')
        data1 = base64.b64encode(buffer1.getvalue()).decode('utf-8')
        buffer2 = io.BytesIO()
        image2.save(buffer2, format='PNG')
        data2 = base64.b64encode(buffer2.getvalue()).decode('utf-8')
        result = await visual_analysis.compare_images(data1, data2, "structural")
        assert result["success"] is True
        assert result["image1_size"] == (100, 100)
        # PIL sizes are (width, height); the reported size is flipped.
        assert result["image2_size"] == (150, 200)  # OpenCV format (height, width)

    @pytest.mark.asyncio
    async def test_compare_images_invalid_data(self, visual_analysis):
        """Test image comparison with invalid image data."""
        invalid_data = "invalid_base64_data"
        result = await visual_analysis.compare_images(
            invalid_data, invalid_data, "structural"
        )
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    @patch('cv2.imdecode')
    async def test_compare_images_decode_failure(self, mock_imdecode, visual_analysis, sample_image_base64):
        """Test image comparison when cv2.imdecode fails."""
        mock_imdecode.return_value = None
        result = await visual_analysis.compare_images(
            sample_image_base64, sample_image_base64, "structural"
        )
        assert result["success"] is False
        assert "Invalid image data" in result["error"]

    # ------------------------------------------------------------------
    # detect_objects
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_detect_objects_success(self, visual_analysis, sample_image_base64):
        """Test successful object detection."""
        result = await visual_analysis.detect_objects(sample_image_base64, 0.7)
        assert result["success"] is True
        assert "objects" in result
        assert result["confidence_threshold"] == 0.7
        assert "image_size" in result
        assert isinstance(result["objects"], list)

    @pytest.mark.asyncio
    async def test_detect_objects_default_confidence(self, visual_analysis, sample_image_base64):
        """Test object detection with default confidence threshold."""
        result = await visual_analysis.detect_objects(sample_image_base64)
        assert result["success"] is True
        assert result["confidence_threshold"] == 0.5

    @pytest.mark.asyncio
    async def test_detect_objects_data_url(self, visual_analysis, sample_image_data_url):
        """Test object detection with data URL format."""
        result = await visual_analysis.detect_objects(sample_image_data_url)
        assert result["success"] is True
        assert "objects" in result

    @pytest.mark.asyncio
    async def test_detect_objects_invalid_data(self, visual_analysis):
        """Test object detection with invalid image data."""
        invalid_data = "invalid_base64_data"
        result = await visual_analysis.detect_objects(invalid_data)
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    @patch('cv2.imdecode')
    async def test_detect_objects_decode_failure(self, mock_imdecode, visual_analysis, sample_image_base64):
        """Test object detection when cv2.imdecode fails."""
        mock_imdecode.return_value = None
        result = await visual_analysis.detect_objects(sample_image_base64)
        assert result["success"] is False
        assert "Invalid image data" in result["error"]

    # ------------------------------------------------------------------
    # enhance_image
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_enhance_image_sharpen(self, visual_analysis, sample_image_base64):
        """Test image enhancement with sharpen filter."""
        result = await visual_analysis.enhance_image(sample_image_base64, "sharpen")
        assert result["success"] is True
        assert "enhanced_image" in result
        assert result["enhancement_type"] == "sharpen"
        assert result["mime_type"] == "image/png"

    @pytest.mark.asyncio
    async def test_enhance_image_blur(self, visual_analysis, sample_image_base64):
        """Test image enhancement with blur filter."""
        result = await visual_analysis.enhance_image(sample_image_base64, "blur")
        assert result["success"] is True
        assert result["enhancement_type"] == "blur"

    @pytest.mark.asyncio
    async def test_enhance_image_brightness(self, visual_analysis, sample_image_base64):
        """Test image enhancement with brightness adjustment."""
        result = await visual_analysis.enhance_image(sample_image_base64, "brightness")
        assert result["success"] is True
        assert result["enhancement_type"] == "brightness"

    @pytest.mark.asyncio
    async def test_enhance_image_contrast(self, visual_analysis, sample_image_base64):
        """Test image enhancement with contrast adjustment."""
        result = await visual_analysis.enhance_image(sample_image_base64, "contrast")
        assert result["success"] is True
        assert result["enhancement_type"] == "contrast"

    @pytest.mark.asyncio
    async def test_enhance_image_auto(self, visual_analysis, sample_image_base64):
        """Test image enhancement with auto enhancement."""
        result = await visual_analysis.enhance_image(sample_image_base64, "auto")
        assert result["success"] is True
        assert result["enhancement_type"] == "auto"

    @pytest.mark.asyncio
    async def test_enhance_image_default_auto(self, visual_analysis, sample_image_base64):
        """Test image enhancement with default (auto) enhancement."""
        result = await visual_analysis.enhance_image(sample_image_base64)
        assert result["success"] is True
        assert result["enhancement_type"] == "auto"

    @pytest.mark.asyncio
    async def test_enhance_image_data_url(self, visual_analysis, sample_image_data_url):
        """Test image enhancement with data URL format."""
        result = await visual_analysis.enhance_image(sample_image_data_url, "sharpen")
        assert result["success"] is True
        assert "enhanced_image" in result

    @pytest.mark.asyncio
    async def test_enhance_image_invalid_data(self, visual_analysis):
        """Test image enhancement with invalid image data."""
        invalid_data = "invalid_base64_data"
        result = await visual_analysis.enhance_image(invalid_data, "sharpen")
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    @patch('cv2.imdecode')
    async def test_enhance_image_decode_failure(self, mock_imdecode, visual_analysis, sample_image_base64):
        """Test image enhancement when cv2.imdecode fails."""
        mock_imdecode.return_value = None
        result = await visual_analysis.enhance_image(sample_image_base64, "sharpen")
        assert result["success"] is False
        assert "Invalid image data" in result["error"]

    @pytest.mark.asyncio
    async def test_enhance_image_output_is_valid_base64(self, visual_analysis, sample_image_base64):
        """Test that enhanced image output is valid base64 and non-empty."""
        result = await visual_analysis.enhance_image(sample_image_base64, "sharpen")
        assert result["success"] is True
        enhanced_image = result["enhanced_image"]
        # Keep only the decode call inside the try block and catch only the
        # errors b64decode raises for bad input (binascii.Error is a
        # ValueError subclass). A failing emptiness assertion must surface as
        # its own AssertionError, not be re-labelled as a base64 problem.
        # NOTE(review): without validate=True, b64decode discards characters
        # outside the base64 alphabet; consider validate=True once it is
        # confirmed that enhance_image never emits line breaks.
        try:
            decoded = base64.b64decode(enhanced_image)
        except (ValueError, TypeError) as e:
            pytest.fail(f"Enhanced image is not valid base64: {e}")
        assert len(decoded) > 0

    # ------------------------------------------------------------------
    # Cross-cutting behavior
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_all_methods_handle_exceptions_gracefully(self, visual_analysis):
        """Test that all methods handle exceptions gracefully.

        Each public method is called with undecodable input and must return
        a failure dictionary carrying a string error instead of raising.
        """
        methods_to_test = [
            ("analyze_image", ["invalid_data"]),
            ("extract_text", ["invalid_data"]),
            ("compare_images", ["invalid_data", "invalid_data"]),
            ("detect_objects", ["invalid_data"]),
            ("enhance_image", ["invalid_data"]),
        ]
        for method_name, args in methods_to_test:
            method = getattr(visual_analysis, method_name)
            result = await method(*args)
            # The messages name the method so a failure inside the loop
            # identifies which of the five calls misbehaved.
            assert result["success"] is False, f"{method_name} did not report failure"
            assert "error" in result, f"{method_name} returned no error entry"
            assert isinstance(result["error"], str), f"{method_name} error is not a string"

    @pytest.mark.asyncio
    async def test_logging_behavior(self, visual_analysis, sample_image_base64, caplog):
        """Test that appropriate logging occurs.

        Five successful calls must yield at least five INFO records, and a
        call with invalid data must yield at least one ERROR record.
        """
        with caplog.at_level("INFO"):
            await visual_analysis.analyze_image(sample_image_base64, "general")
            await visual_analysis.extract_text(sample_image_base64)
            await visual_analysis.compare_images(sample_image_base64, sample_image_base64)
            await visual_analysis.detect_objects(sample_image_base64)
            await visual_analysis.enhance_image(sample_image_base64)
        # Check that info logs were created
        info_logs = [record for record in caplog.records if record.levelname == "INFO"]
        assert len(info_logs) >= 5
        # Drop the records captured so far so that the ERROR check below can
        # only be satisfied by the invalid-data call, not by a stray ERROR
        # emitted during the success-path calls above.
        caplog.clear()
        # Test error logging
        with caplog.at_level("ERROR"):
            await visual_analysis.analyze_image("invalid_data")
        error_logs = [record for record in caplog.records if record.levelname == "ERROR"]
        assert len(error_logs) >= 1

    @pytest.mark.asyncio
    async def test_concurrent_operations(self, visual_analysis, sample_image_base64):
        """Test that multiple operations can run concurrently."""
        tasks = [
            visual_analysis.analyze_image(sample_image_base64, "general"),
            visual_analysis.extract_text(sample_image_base64),
            visual_analysis.detect_objects(sample_image_base64),
            visual_analysis.enhance_image(sample_image_base64, "sharpen"),
        ]
        results = await asyncio.gather(*tasks)
        # All operations should succeed
        for result in results:
            assert result["success"] is True