import asyncio
import json
import os
import sys
import tempfile
from unittest.mock import MagicMock, mock_open, patch

import pytest
# Add the current directory to sys.path to import main
sys.path.insert(0, os.path.dirname(__file__))
from main import get_pihole_status, enable_pihole, disable_pihole, get_pihole_summary, pihole_client, load_config
class TestConfiguration:
    """Test configuration loading and server modes."""

    def test_load_config_default(self):
        """Test loading default configuration when file doesn't exist."""
        # Presumably load_config() locates its config file relative to
        # os.path.dirname(...); pointing that at a nonexistent directory
        # forces the built-in defaults.
        with patch('os.path.dirname') as mock_dirname:
            mock_dirname.return_value = '/nonexistent'
            config = load_config()
        # Assert outside the patch so pytest's reporting never runs with a
        # patched os.path.dirname.
        assert config["server"]["mode"] == "stdio"
        assert config["server"]["port"] == 5000
        assert "pihole" in config
        assert "base_url" in config["pihole"]

    def test_load_config_from_file(self):
        """Test loading configuration from a config file's contents.

        ``builtins.open`` is replaced with ``mock_open`` so load_config()
        reads ``test_config`` as JSON without touching the disk, and
        ``os.path.dirname`` is redirected to the system temp directory.
        ``mock_open`` supports ``read()`` both with and without a ``with``
        block, unlike a bare MagicMock wired only through ``__enter__``.
        """
        test_config = {
            "server": {
                "mode": "port",
                "port": 8080
            },
            "pihole": {
                "base_url": "http://test.example.com/api"
            }
        }
        # Build both stand-ins before patching so neither is computed while
        # os.path.dirname is patched.
        config_dir = tempfile.gettempdir()
        fake_open = mock_open(read_data=json.dumps(test_config))
        with patch('os.path.dirname', return_value=config_dir), \
                patch('builtins.open', fake_open):
            config = load_config()
        assert config["server"]["mode"] == "port"
        assert config["server"]["port"] == 8080
        assert config["pihole"]["base_url"] == "http://test.example.com/api"

    @patch('main.mcp.run')
    def test_main_stdio_mode(self, mock_run):
        """Test main function with stdio mode: run() is called with no args."""
        with patch('main.config', {"server": {"mode": "stdio", "port": 5000}, "pihole": {"base_url": "http://test.com/api"}}), \
                patch.dict(os.environ, {'PIHOLE_APP_PASSWORD': 'dummy'}):
            from main import main
            main()
        mock_run.assert_called_once_with()

    @patch('main.mcp.run')
    @patch('main.mcp.app', create=True, new_callable=MagicMock)
    def test_main_port_mode(self, mock_app, mock_run):
        """Test main function with port mode: run() uses streamable HTTP.

        ``mock_app`` is unused; the patch only stubs ``main.mcp.app``
        (created if absent) for the duration of the test.
        """
        with patch('main.config', {"server": {"mode": "port", "port": 8080}, "pihole": {"base_url": "http://test.com/api"}}), \
                patch.dict(os.environ, {'PIHOLE_APP_PASSWORD': 'dummy'}):
            from main import main
            main()
        mock_run.assert_called_once_with(transport="streamable-http")

    @patch('main.mcp.run')
    def test_main_invalid_mode(self, mock_run):
        """Test main function with invalid mode raises and never starts the server."""
        with patch('main.config', {"server": {"mode": "invalid", "port": 5000}, "pihole": {"base_url": "http://test.com/api"}}), \
                patch.dict(os.environ, {'PIHOLE_APP_PASSWORD': 'dummy'}):
            from main import main
            with pytest.raises(ValueError, match="Unknown server mode"):
                main()
        # An unknown mode must be rejected before any transport is started.
        mock_run.assert_not_called()
class TestPiHoleMCP:
    """Unit tests for the Pi-hole MCP tool functions.

    The success/error tests replace the shared client's ``_get`` / ``_post``
    helpers with mocks, so results are driven by canned payloads or injected
    exceptions. The "no password" tests run with an emptied environment.
    """

    # Environment overlay used whenever a tool call should see a password.
    _PASSWORD_ENV = {'PIHOLE_APP_PASSWORD': 'test_password'}

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_get')
    async def test_get_pihole_status_success(self, fake_get):
        """Enabled blocking with no timer is reported as permanent."""
        fake_get.return_value = {"blocking": "enabled", "timer": None}
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await get_pihole_status()
        assert "Pi-hole is enabled (permanent)" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_get')
    async def test_get_pihole_status_error(self, fake_get):
        """A failing GET comes back as an error message, not an exception."""
        fake_get.side_effect = Exception("Connection error")
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await get_pihole_status()
        assert "Error getting Pi-hole status" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_post')
    async def test_enable_pihole_success(self, fake_post):
        """Enabling reports success when the POST returns an enabled state."""
        fake_post.return_value = {"blocking": "enabled", "timer": None}
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await enable_pihole()
        assert "Pi-hole enabled:" in message

    @pytest.mark.asyncio
    async def test_enable_pihole_no_password(self):
        """Enabling with an empty environment yields an error message."""
        with patch.dict(os.environ, {}, clear=True):
            message = await enable_pihole()
        assert "Error enabling Pi-hole" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_post')
    async def test_disable_pihole_success(self, fake_post):
        """Disabling without a duration is reported as permanent."""
        fake_post.return_value = {"blocking": "disabled", "timer": None}
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await disable_pihole()
        assert "Pi-hole disabled permanently:" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_post')
    async def test_disable_pihole_with_duration(self, fake_post):
        """Disabling with a duration echoes that duration in seconds."""
        fake_post.return_value = {"blocking": "disabled", "timer": 300}
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await disable_pihole(duration=300)
        assert "Pi-hole disabled for 300 seconds:" in message

    @pytest.mark.asyncio
    async def test_disable_pihole_no_password(self):
        """Disabling with an empty environment yields an error message."""
        with patch.dict(os.environ, {}, clear=True):
            message = await disable_pihole()
        assert "Error disabling Pi-hole" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_get')
    async def test_get_pihole_summary_success(self, fake_get):
        """The summary line starts with the blocked-domain count."""
        fake_get.return_value = {
            "domains_being_blocked": 1000,
            "dns_queries_today": 5000,
            "ads_blocked_today": 200,
            "ads_percentage_today": 4.0
        }
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await get_pihole_summary()
        assert "Pi-hole summary: 1000 domains blocked" in message

    @pytest.mark.asyncio
    @patch.object(pihole_client, '_get')
    async def test_get_pihole_summary_error(self, fake_get):
        """A failing GET during summary comes back as an error message."""
        fake_get.side_effect = Exception("Network error")
        with patch.dict(os.environ, self._PASSWORD_ENV):
            message = await get_pihole_summary()
        assert "Error getting Pi-hole summary" in message
class TestPiHoleIntegration:
    """Integration tests that require a real Pi-hole connection.

    Every test in this class is skipped unless ``PIHOLE_APP_PASSWORD`` is
    set in the environment at collection time.
    """

    # One class-wide skip marker instead of repeating skipif on each test.
    pytestmark = [
        pytest.mark.skipif(
            not os.environ.get('PIHOLE_APP_PASSWORD'),
            reason="PIHOLE_APP_PASSWORD environment variable not set"
        ),
    ]

    @pytest.mark.asyncio
    async def test_integration_get_status(self):
        """Integration test: Get Pi-hole status."""
        result = await get_pihole_status()
        assert "Pi-hole is" in result

    @pytest.mark.asyncio
    async def test_integration_get_summary(self):
        """Integration test: Get Pi-hole summary."""
        result = await get_pihole_summary()
        assert "Pi-hole summary:" in result

    @pytest.mark.asyncio
    async def test_integration_enable_disable(self):
        """Integration test: Enable and disable Pi-hole (be careful - this affects blocking!).

        Blocking is re-enabled in ``finally`` so the real device is not left
        disabled after the test, even if an assertion fails part-way.
        """
        # First get current status (logged for troubleshooting only).
        status_result = await get_pihole_status()
        print(f"Current status: {status_result}")
        try:
            enable_result = await enable_pihole()
            print(f"Enable result: {enable_result}")
            assert "Pi-hole enabled:" in enable_result

            # Give the Pi-hole a moment before toggling again.
            await asyncio.sleep(2)

            # Disable for 10 seconds.
            disable_result = await disable_pihole(duration=10)
            print(f"Disable result: {disable_result}")
            assert "Pi-hole disabled for 10 seconds:" in disable_result
        finally:
            # Restore blocking right away rather than leaving the network
            # unprotected; no assert here so an earlier failure isn't masked.
            restore_result = await enable_pihole()
            print(f"Restore result: {restore_result}")