import asyncio
import os
import requests
import json
import mcp.server.streamable_http as streamable_http
from fastapi import Request, HTTPException
import uvicorn
from mcp.server.fastmcp import FastMCP
from fastapi import FastAPI
# Create the FastAPI application (no middleware is registered in this file).
app = FastAPI()
# Initialize the FastMCP server under the name "pihole-control".
mcp = FastMCP("pihole-control")
# Attach the FastAPI app to the server object as a plain attribute.
# NOTE(review): nothing in this file reads `mcp.app`; confirm FastMCP uses it.
mcp.app = app
# Load configuration
def load_config(config_path=None):
    """Load the server configuration, filling any gaps with built-in defaults.

    Args:
        config_path: Path of the JSON file to read. When omitted, ``config.json``
            in the directory containing this module is used.

    Returns:
        A dict that always contains a ``"server"`` section (``mode``, ``port``)
        and a ``"pihole"`` section (``base_url``). Keys present in the file win
        over the defaults; extra top-level keys from the file are preserved.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), if its top level is not a JSON object, or
            if a known section is present but is not an object.
    """
    defaults = {
        "server": {
            "mode": "stdio",
            "port": 5000
        },
        "pihole": {
            "base_url": "http://192.168.68.59:8081/api"
        }
    }
    if config_path is None:
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
    try:
        # Explicit encoding so the result does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # No config file at all: run entirely on the defaults.
        return defaults
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a JSON object"
        )
    # Merge section by section so a partial file (e.g. only "server") does not
    # cause a KeyError when callers index config["pihole"]["base_url"].
    merged = dict(loaded)
    for section, section_defaults in defaults.items():
        user_section = loaded.get(section)
        if user_section is None:
            merged[section] = dict(section_defaults)
        elif isinstance(user_section, dict):
            merged[section] = {**section_defaults, **user_section}
        else:
            raise ValueError(
                f"Configuration section '{section}' must be a JSON object"
            )
    return merged
# Load the configuration once at import time; the Pi-hole API base URL used by
# every request below is taken from its "pihole" section.
config = load_config()
PIHOLE_BASE_URL = config["pihole"]["base_url"]
class PiHoleClient:
    """Minimal session-based client for the Pi-hole HTTP API.

    The client authenticates with the password from the ``PIHOLE_APP_PASSWORD``
    environment variable, caches the returned session id (SID) and sends it as
    the ``sid`` query parameter on every request.

    Attributes are plain instance fields:
        session_id: SID returned by the last successful login, or ``None``.
        session_validity: Unix timestamp after which the SID is considered
            expired (``0`` before the first login).
        timeout: Per-request timeout in seconds passed to ``requests``.
    """

    # Fallback session lifetime (seconds) when the server does not report one.
    DEFAULT_SESSION_SECONDS = 1800
    # Refresh the session when it has less than this many seconds left.
    REFRESH_MARGIN_SECONDS = 300

    def __init__(self, timeout=10):
        """Create a client with no session yet.

        Args:
            timeout: Seconds to wait for each HTTP request before giving up.
                Without a timeout ``requests`` may block indefinitely.
        """
        self.session_id = None
        self.session_validity = 0
        self.timeout = timeout

    def _ensure_session(self):
        """Ensure we have a valid session, refreshing if necessary.

        Raises:
            ValueError: If ``PIHOLE_APP_PASSWORD`` is unset or the server
                reports the login as invalid.
            requests.HTTPError: If the auth endpoint returns an error status.
        """
        import time

        current_time = time.time()
        # Reuse the cached session while it is valid for at least the margin.
        if (
            self.session_id
            and self.session_validity > current_time + self.REFRESH_MARGIN_SECONDS
        ):
            return
        # Get new session
        app_password = os.environ.get("PIHOLE_APP_PASSWORD")
        if not app_password:
            raise ValueError("PIHOLE_APP_PASSWORD environment variable not set")
        response = requests.post(
            f"{PIHOLE_BASE_URL}/auth",
            json={"password": app_password},
            timeout=self.timeout,
        )
        response.raise_for_status()
        session = response.json().get("session") or {}
        if not session.get("valid"):
            raise ValueError("Authentication failed")
        self.session_id = session.get("sid")
        # Prefer the lifetime reported by the server; fall back to 30 minutes.
        # NOTE(review): assumes the auth response carries "validity" in seconds
        # - confirm against the Pi-hole API documentation.
        validity = session.get("validity")
        if (
            isinstance(validity, bool)
            or not isinstance(validity, (int, float))
            or validity <= 0
        ):
            validity = self.DEFAULT_SESSION_SECONDS
        self.session_validity = current_time + validity

    def _get(self, endpoint):
        """Make authenticated GET request and return the decoded JSON body.

        Args:
            endpoint: API path appended to ``PIHOLE_BASE_URL`` (e.g.
                ``"/dns/blocking"``).

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        self._ensure_session()
        # Pass the SID via `params` so requests URL-encodes it; building the
        # query string by hand corrupts SIDs containing '+', '/' or '='.
        response = requests.get(
            f"{PIHOLE_BASE_URL}{endpoint}",
            params={"sid": self.session_id},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _post(self, endpoint, data):
        """Make authenticated POST request and return the decoded JSON body.

        Args:
            endpoint: API path appended to ``PIHOLE_BASE_URL``.
            data: JSON-serializable payload sent as the request body.

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        self._ensure_session()
        response = requests.post(
            f"{PIHOLE_BASE_URL}{endpoint}",
            params={"sid": self.session_id},
            json=data,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()
# Global client instance shared by all tool functions below, so the cached
# session is reused across tool calls.
pihole_client = PiHoleClient()
@mcp.tool()
async def get_pihole_status():
    """Get the current status of Pi-hole (enabled/disabled)."""
    # NOTE(review): the docstring above is left unchanged because FastMCP
    # presumably publishes it as the tool description - confirm.
    #
    # Returns a human-readable sentence: "enabled" only when the API reports
    # blocking == "enabled" (anything else is reported as "disabled"), plus
    # the remaining timer when one is set. Any failure is reported as a
    # generic error string rather than raised to the caller.
    try:
        # The client uses blocking `requests` calls; run it in a worker thread
        # so the event loop stays responsive while waiting on the network.
        data = await asyncio.to_thread(pihole_client._get, "/dns/blocking")
        blocking = data.get("blocking")
        timer = data.get("timer")
        status = "enabled" if blocking == "enabled" else "disabled"
        if timer:
            return f"Pi-hole is {status} (temporary, {timer} seconds remaining)"
        return f"Pi-hole is {status} (permanent)"
    except Exception:
        # Deliberately generic: error details may include URLs or credentials.
        return "Error getting Pi-hole status. Please check the configuration and credentials."
@mcp.tool()
async def enable_pihole():
    """Enable Pi-hole ad blocking."""
    # NOTE(review): the docstring above is left unchanged because FastMCP
    # presumably publishes it as the tool description - confirm.
    #
    # Posts {"blocking": True} to /dns/blocking and returns a sentence that
    # embeds the raw API response. Any failure is reported as a generic error
    # string rather than raised to the caller.
    try:
        # Run the blocking HTTP call in a worker thread so the event loop is
        # not stalled while waiting on the network.
        data = await asyncio.to_thread(
            pihole_client._post, "/dns/blocking", {"blocking": True}
        )
        return f"Pi-hole enabled: {data}"
    except Exception:
        # Deliberately generic: error details may include URLs or credentials.
        return "Error enabling Pi-hole. Please check the configuration and credentials."
@mcp.tool()
async def disable_pihole(duration: int = 0):
    """Disable Pi-hole ad blocking. Optionally specify duration in seconds."""
    # NOTE(review): the docstring above is left unchanged because FastMCP
    # presumably publishes it as the tool description - confirm.
    #
    # A positive `duration` is sent as the "timer" field; zero or a negative
    # value sends no timer and is reported as a permanent disable. Any failure
    # is reported as a generic error string rather than raised to the caller.
    try:
        timed = duration > 0
        payload = {"blocking": False}
        if timed:
            payload["timer"] = duration
        # Run the blocking HTTP call in a worker thread so the event loop is
        # not stalled while waiting on the network.
        data = await asyncio.to_thread(
            pihole_client._post, "/dns/blocking", payload
        )
        if timed:
            return f"Pi-hole disabled for {duration} seconds: {data}"
        return f"Pi-hole disabled permanently: {data}"
    except Exception:
        # Deliberately generic: error details may include URLs or credentials.
        return "Error disabling Pi-hole. Please check the configuration and credentials."
@mcp.tool()
async def get_pihole_summary():
    """Get a summary of Pi-hole statistics."""
    # NOTE(review): the docstring above is left unchanged because FastMCP
    # presumably publishes it as the tool description - confirm.
    #
    # Returns one sentence with the number of blocked domains, total queries,
    # blocked queries and the blocked percentage. Fields that cannot be found
    # are shown as "N/A". Any failure is reported as a generic error string
    # rather than raised to the caller.
    try:
        # Run the blocking HTTP call in a worker thread so the event loop is
        # not stalled while waiting on the network.
        data = await asyncio.to_thread(pihole_client._get, "/stats/summary")

        def pick(section, key, legacy_key):
            # Prefer the nested layout (section -> key); fall back to the flat
            # legacy key at the top level, then to "N/A".
            # NOTE(review): nested names follow the Pi-hole v6 /stats/summary
            # schema as I understand it - confirm against the API docs.
            nested = data.get(section)
            if isinstance(nested, dict) and key in nested:
                return nested[key]
            return data.get(legacy_key, "N/A")

        # Extract key statistics
        domains_blocked = pick("gravity", "domains_being_blocked", "domains_being_blocked")
        dns_queries_today = pick("queries", "total", "dns_queries_today")
        ads_blocked_today = pick("queries", "blocked", "ads_blocked_today")
        ads_percentage_today = pick("queries", "percent_blocked", "ads_percentage_today")
        if isinstance(ads_percentage_today, float):
            # Avoid printing a long raw float such as 46.21530151367188.
            ads_percentage_today = f"{ads_percentage_today:.1f}"
        return f"Pi-hole summary: {domains_blocked} domains blocked, {dns_queries_today} DNS queries today, {ads_blocked_today} ads blocked today ({ads_percentage_today}%)."
    except Exception:
        # Deliberately generic: error details may include URLs or credentials.
        return "Error getting Pi-hole summary. Please check the configuration and credentials."
def main():
    """Validate the environment and start the MCP server in the configured mode.

    Reads the ``"server"`` section of the module-level ``config``:
        mode: ``"stdio"`` (default) runs over standard input/output;
            ``"port"`` serves the streamable-HTTP transport.
        port: TCP port for ``"port"`` mode (default 5000).
        host: Bind address for ``"port"`` mode (default ``"0.0.0.0"``, i.e.
            all interfaces). Set it to ``"127.0.0.1"`` to keep the server
            reachable from the local machine only.

    Raises:
        ValueError: If ``PIHOLE_APP_PASSWORD`` is not set, or if the
            configured mode is neither ``"stdio"`` nor ``"port"``.
    """
    # Fail fast at startup instead of on the first tool call.
    if not os.environ.get("PIHOLE_APP_PASSWORD"):
        raise ValueError("PIHOLE_APP_PASSWORD environment variable not set")
    server_config = config["server"]
    mode = server_config.get("mode", "stdio")
    if mode == "stdio":
        mcp.run()
    elif mode == "port":
        # Configure the server to run on the specified host and port. The host
        # default keeps the previous behavior of listening on all interfaces.
        mcp.settings.host = server_config.get("host", "0.0.0.0")
        mcp.settings.port = server_config.get("port", 5000)
        mcp.run(transport="streamable-http")
    else:
        raise ValueError(f"Unknown server mode: {mode}. Supported modes: 'stdio', 'port'")
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()