"""
Tests for Weather MCP Visualization.
"""
import pytest
import pandas as pd
import matplotlib.pyplot as plt
import sys
from pathlib import Path
# Add src to path for testing
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from weather_mcp.visualization import WeatherVisualizer, create_visualizer
class TestWeatherVisualizer:
    """Test cases for WeatherVisualizer and the create_visualizer factory."""

    # Figure size both WeatherVisualizer() and create_visualizer() are
    # expected to use when no figsize argument is supplied.
    DEFAULT_FIGSIZE = (15, 5)

    @pytest.fixture(autouse=True)
    def _close_figures(self):
        """Close every open matplotlib figure after each test.

        Closing in fixture teardown runs whether the test passed or failed,
        so a failing assertion can no longer leak an open figure into later
        tests (closing after the asserts inside each test was skipped on
        failure).
        """
        yield
        plt.close("all")

    @pytest.fixture
    def sample_weather_data(self):
        """Return a three-hour sample payload with an ``hourly`` section.

        A fresh dict is built per test, so tests may not observe each
        other's mutations.
        """
        return {
            "hourly": {
                "time": ["2024-01-01T00:00", "2024-01-01T01:00", "2024-01-01T02:00"],
                "temperature_2m": [20.5, 21.0, 19.8],
                "relative_humidity_2m": [65, 67, 70],
                "precipitation_probability": [10, 15, 5],
                "wind_speed_10m": [5.2, 4.8, 6.1],
            }
        }

    def test_visualizer_init(self):
        """Test visualizer initialization with default and custom figsize."""
        assert WeatherVisualizer().figsize == self.DEFAULT_FIGSIZE
        assert WeatherVisualizer(figsize=(10, 8)).figsize == (10, 8)

    def test_prepare_data(self, sample_weather_data):
        """Test that prepare_data yields one row per hour with parsed times."""
        viz = WeatherVisualizer()
        df = viz.prepare_data(sample_weather_data)

        assert isinstance(df, pd.DataFrame)
        assert len(df) == len(sample_weather_data["hourly"]["time"])
        assert "time" in df.columns
        assert "temperature_2m" in df.columns
        assert pd.api.types.is_datetime64_any_dtype(df["time"])

    def test_prepare_data_missing_hourly(self):
        """Test that prepare_data rejects input without an 'hourly' key."""
        viz = WeatherVisualizer()
        invalid_data = {"no_hourly": "data"}

        with pytest.raises(ValueError, match="Weather data missing 'hourly' key"):
            viz.prepare_data(invalid_data)

    def test_plot_temperature_trend(self, sample_weather_data):
        """Test temperature trend plotting produces a single titled axis."""
        viz = WeatherVisualizer()
        fig = viz.plot_temperature_trend(sample_weather_data)

        assert isinstance(fig, plt.Figure)
        axes = fig.get_axes()
        assert len(axes) == 1
        assert "Temperature" in axes[0].get_title()

    def test_plot_weather_overview(self, sample_weather_data):
        """Test weather overview plotting produces three axes."""
        viz = WeatherVisualizer()
        fig = viz.plot_weather_overview(sample_weather_data)

        assert isinstance(fig, plt.Figure)
        # Temperature, Humidity, Precipitation
        assert len(fig.get_axes()) == 3

    def test_plot_all_metrics(self, sample_weather_data):
        """Test all-metrics plotting produces a 2x2 grid of axes."""
        viz = WeatherVisualizer()
        fig = viz.plot_all_metrics(sample_weather_data, "Test City")

        assert isinstance(fig, plt.Figure)
        assert len(fig.get_axes()) == 4  # 2x2 grid

    def test_get_weather_summary(self, sample_weather_data):
        """Test weather summary sections and temperature statistics."""
        viz = WeatherVisualizer()
        summary = viz.get_weather_summary(sample_weather_data)

        # Report every missing section at once instead of stopping at the first.
        expected_sections = ("temperature", "humidity", "precipitation", "wind", "period")
        missing = [key for key in expected_sections if key not in summary]
        assert not missing, f"summary missing sections: {missing}"

        # Expected statistics are derived from the fixture so the two cannot drift.
        hourly = sample_weather_data["hourly"]
        temps = hourly["temperature_2m"]
        temp_stats = summary["temperature"]
        assert temp_stats["min"] == pytest.approx(min(temps))
        assert temp_stats["max"] == pytest.approx(max(temps))
        # Loose tolerance (same as before) in case the implementation rounds.
        assert temp_stats["mean"] == pytest.approx(sum(temps) / len(temps), abs=0.1)

        assert summary["period"]["hours"] == len(hourly["time"])

    def test_create_visualizer_function(self):
        """Test the create_visualizer convenience function."""
        viz = create_visualizer()
        assert isinstance(viz, WeatherVisualizer)
        assert viz.figsize == self.DEFAULT_FIGSIZE

        assert create_visualizer(figsize=(12, 8)).figsize == (12, 8)