#!/usr/bin/env python3
import json
import os
import sys
import tempfile
import traceback
import logging
import ipaddress
import ast
import argparse
from typing import Any, Dict, List, Literal, TypedDict, Optional
import asyncio
from mcp.server.fastmcp import FastMCP
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class OutletConfig(TypedDict):
    """Per-outlet entry stored in a switch's ``outlets`` mapping."""

    # Outlet label; resolve_outlet() compares it case-insensitively.
    name: str
    # Safety class checked by power_action(): "prohibited" outlets are never
    # switched, "critical" ones need confirmation='YES' for off/cycle.
    type: Literal["standard", "critical", "prohibited"]
    # Free-form description of the outlet.
    description: str
    # Not stored by the sync/add tools; get_inventory() adds it at runtime.
    status: Optional[str]  # Added during runtime
class SwitchConfig(TypedDict):
    """One entry of the ``switches`` list in the configuration file."""

    # Handle for the switch; resolve_switch() accepts this or ip_address.
    alias: str
    ip_address: str
    # Credentials passed to the switch client by get_client().
    username: str
    password: str
    # Outlet definitions keyed by the 1-based outlet index as a string.
    outlets: Dict[str, OutletConfig]
    # Name or model read from the client by _sync_switch_config_sync().
    controller_name: Optional[str]
    description: Optional[str]
# The hardware client library is optional at import time.
try:
    from power_switch_pro import PowerSwitchPro as RealDLIPowerSwitch
except ImportError:
    # Define a stand-in class if the library isn't installed, so the server can still start.
    # Constructing it raises ImportError, so all real calls will fail, but mock calls will work.
    class RealDLIPowerSwitch:
        def __init__(self, *args, **kwargs):
            raise ImportError("power-switch-pro library not found. Please install it to use real hardware.")
class TimeoutDLIPowerSwitch(RealDLIPowerSwitch):
    """
    RealDLIPowerSwitch variant whose requests always use a fixed timeout.

    The power-switch-pro library hardcodes a 30s timeout, which might be too long,
    so this instance's ``session.request`` is replaced by a wrapper that
    overrides the ``timeout`` keyword on every call.
    """

    def __init__(self, host, username, password, timeout=5.0, **kwargs):
        super().__init__(host, username, password, **kwargs)

        # Nothing to patch when the parent class did not create a session.
        if not hasattr(self, 'session'):
            return

        send = self.session.request

        def request_with_timeout(method, url, *req_args, **req_kwargs):
            # Always win over whatever timeout the library passes (e.g. 30).
            req_kwargs['timeout'] = timeout
            return send(method, url, *req_args, **req_kwargs)

        self.session.request = request_with_timeout
# Path of the JSON configuration file; override it with the DLI_MCP_CONFIG environment variable.
CONFIG_FILE = os.environ.get("DLI_MCP_CONFIG", "switches_config.json")
# Held by the tools that load, modify and save the configuration (add/remove/update/sync).
CONFIG_LOCK = asyncio.Lock()
# FastMCP server instance on which the tool coroutines below are registered.
server = FastMCP(name="dli-mcp-server")
def validate_ip(ip: str) -> None:
    """
    Check that ``ip`` is a syntactically valid IPv4 or IPv6 address.

    Args:
        ip: Address text to validate.

    Raises:
        ValueError: If ``ipaddress.ip_address`` cannot parse ``ip``.
    """
    try:
        ipaddress.ip_address(ip)
    except ValueError as exc:
        # Keep the parser's own error attached as the cause for debugging.
        raise ValueError(f"Invalid IP address: {ip}") from exc
def get_client(switch_config: SwitchConfig):
    """
    Build the client object used to talk to one switch.

    When the DLI_MCP_ENV environment variable equals "TEST", the mock client
    from ``tests/mock_device.py`` (next to this file) is returned; otherwise a
    TimeoutDLIPowerSwitch with a 5 second request timeout is created.

    Raises:
        ImportError: In TEST mode, if ``mock_device`` cannot be imported.
    """
    if os.environ.get("DLI_MCP_ENV") != "TEST":
        return TimeoutDLIPowerSwitch(
            host=switch_config["ip_address"],
            username=switch_config.get("username"),
            password=switch_config.get("password"),
            timeout=5.0,  # Set explicit timeout
        )

    # Make the tests directory importable so mock_device can be found.
    here = os.path.dirname(os.path.abspath(__file__))
    tests_dir = os.path.join(here, 'tests')
    if tests_dir not in sys.path:
        sys.path.append(tests_dir)

    try:
        from mock_device import MockDLIPowerSwitch
    except ImportError:
        # The tests directory is missing or not laid out as expected.
        raise ImportError("Could not import mock_device from tests directory. Ensure 'tests/mock_device.py' exists.")

    return MockDLIPowerSwitch(
        ip=switch_config["ip_address"],
        username=switch_config.get("username"),
        password=switch_config.get("password"),
        timeout=2.0,
    )
def load_config() -> Dict[str, Any]:
    """
    Read and parse the JSON configuration file named by CONFIG_FILE.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: If the file does not contain valid JSON.
    """
    # Open directly instead of checking os.path.exists() first: the file could
    # disappear between the check and the open.
    try:
        with open(CONFIG_FILE, "r", encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Configuration file not found: {CONFIG_FILE}") from e
    except json.JSONDecodeError as e:
        # Chain the decoder error so the line/column details stay available.
        raise ValueError(f"Configuration file '{CONFIG_FILE}' contains invalid JSON: {e}") from e
def save_config(config: Dict[str, Any]):
    """
    Persist ``config`` to CONFIG_FILE as indented JSON.

    The data is written to a temporary file in the same directory, flushed and
    fsynced, and then moved over CONFIG_FILE with ``os.replace``. If any step
    fails, the temporary file is removed and the error is re-raised, so failed
    saves do not leave stray files next to the configuration.

    Args:
        config: JSON-serialisable configuration document.
    """
    # The temp file must live in the same directory for the rename to work.
    dir_name = os.path.dirname(os.path.abspath(CONFIG_FILE))
    temp_name = None
    try:
        # delete=False is required so we can close it and then replace the original file
        with tempfile.NamedTemporaryFile('w', dir=dir_name, delete=False, encoding='utf-8') as tf:
            temp_name = tf.name
            json.dump(config, tf, indent=2)
            tf.flush()
            os.fsync(tf.fileno())  # Force write to physical disk
        # Atomic replacement
        os.replace(temp_name, CONFIG_FILE)
    except BaseException:
        # Best-effort cleanup; the original failure is what the caller needs.
        if temp_name is not None:
            try:
                os.unlink(temp_name)
            except OSError:
                pass
        raise
def resolve_switch(identifier: str) -> SwitchConfig:
    """
    Find a switch in the freshly loaded configuration by alias or IP address.

    Returns:
        The first switch entry whose ``alias`` or ``ip_address`` equals
        ``identifier``.

    Raises:
        ValueError: If no switch matches.
    """
    candidates = (
        entry
        for entry in load_config()["switches"]
        if entry["alias"] == identifier or entry["ip_address"] == identifier
    )
    found = next(candidates, None)
    if found is None:
        raise ValueError("Device not found")
    return found
def resolve_outlet(switch: SwitchConfig, identifier: str) -> str:
    """
    Translate an outlet index or name into the outlet's key in ``switch["outlets"]``.

    An exact key match wins; otherwise outlet names are compared
    case-insensitively and the first match is returned.

    Raises:
        ValueError: If neither a key nor a name matches.
    """
    outlets = switch["outlets"]
    if identifier in outlets:
        return identifier

    by_name = (
        key
        for key, data in outlets.items()
        if data.get("name", "").lower() == identifier.lower()
    )
    match = next(by_name, None)
    if match is None:
        raise ValueError("Outlet not found")
    return match
# --- Synchronous Helpers for Threading ---
def _fetch_switch_status_sync(switch_config: SwitchConfig) -> Optional[Dict[str, Any]]:
    """
    Blocking helper: read the state of every outlet on one switch.

    Returns:
        A dict mapping the 1-based outlet index (as a string) to the state
        reported by ``get_all_states()``, or None if anything fails. Failures
        are logged, not raised.
    """
    try:
        # The length of get_all_states() decides how many outlets exist, which
        # avoids client.outlets.count() (unreliable on some firmware).
        states = get_client(switch_config).outlets.get_all_states()
        return {str(position): state for position, state in enumerate(states, start=1)}
    except Exception as e:
        logger.error(f"Error fetching status for switch {switch_config.get('alias')}: {e}")
        return None
def _power_action_sync(switch_config: SwitchConfig, outlet_index: int, action: str) -> None:
"""
Synchronously performs a power action.
"""
client = get_client(switch_config)
outlet = client.outlets[outlet_index]
if action == "on":
outlet.on()
elif action == "off":
outlet.off()
elif action == "cycle":
outlet.cycle()
def _list_outlets_sync(switch_config: SwitchConfig) -> List[Dict[str, Any]]:
    """
    Blocking helper: describe every outlet of a switch with its live state.

    Each entry has ``index`` (1-based), ``name`` and ``status`` read from the
    hardware, plus ``type`` and ``description`` from the configuration. An
    outlet that cannot be read yields a placeholder entry carrying the error
    text; if the client itself cannot be created, eight placeholders are
    returned.
    """
    def placeholder(index, e):
        # Entry used whenever an outlet (or the whole switch) is unreachable.
        return {
            "index": index,
            "name": f"Unknown (Error: {e})",
            "status": "unknown",
            "type": "unknown",
            "description": ""
        }

    rows = []
    try:
        client = get_client(switch_config)

        # Ask the hardware how many outlets it has; assume 8 if that fails.
        try:
            total = len(client.outlets.get_all_states())
        except Exception:
            total = 8

        for position in range(1, total + 1):
            try:
                hw_outlet = client.outlets[position - 1]
                # These property reads may block on the network.
                is_on = hw_outlet.state
                label = hw_outlet.name
                configured = switch_config["outlets"].get(str(position), {})
                row = {
                    "index": position,
                    "name": label,
                    "status": "on" if is_on else "off",
                    "type": configured.get("type", "standard"),
                    "description": configured.get("description", "")
                }
            except Exception as e:
                row = placeholder(position, e)
            rows.append(row)
    except Exception as e:
        # Client creation failed: report a default set of unknown outlets.
        rows.extend(placeholder(position, e) for position in range(1, 9))
    return rows
def _update_outlet_name_sync(switch_config: SwitchConfig, outlet_index: int, new_name: str) -> None:
    """
    Blocking helper: rename one outlet on the hardware.

    Uses the outlet's ``set_name`` method when it has one and falls back to
    assigning its ``name`` attribute otherwise. ``outlet_index`` is the
    zero-based position in ``client.outlets``.
    """
    target = get_client(switch_config).outlets[outlet_index]
    if not hasattr(target, "set_name"):
        target.name = new_name
        return
    target.set_name(new_name)
def _sync_switch_config_sync(switch_config: SwitchConfig) -> None:
    """
    Blocking helper: pull names from the hardware into ``switch_config`` in place.

    Sets ``controller_name`` from the client's ``name`` (or ``model``) when
    available, then copies every outlet name. Outlets not yet in the
    configuration are created with type "standard" and an empty description.
    Outlets that cannot be read are logged and skipped.
    """
    client = get_client(switch_config)

    reported_name = getattr(client, 'name', None) or getattr(client, 'model', None)
    if reported_name:
        switch_config["controller_name"] = reported_name

    # Ask the hardware how many outlets it has; assume 8 if that fails.
    try:
        outlet_count = len(client.outlets.get_all_states())
    except Exception:
        outlet_count = 8

    for i in range(1, outlet_count + 1):
        key = str(i)
        try:
            # Reading the name may block on the network.
            label = client.outlets[i - 1].name
            known = switch_config["outlets"]
            if key in known:
                known[key]["name"] = label
            else:
                known[key] = {"name": label, "type": "standard", "description": ""}
        except Exception:
            # Most likely the outlet does not exist; warn and move on.
            logger.warning(f"Could not sync outlet {i} on {switch_config.get('alias', 'unknown')}.", exc_info=True)
# --- MCP Tools ---
@server.tool()
async def get_inventory() -> Dict[str, Any]:
    """
    Return the configuration with a live ``status`` on every configured outlet.

    All switches are queried in parallel. Each configured outlet receives the
    state reported by its switch, or "unknown" when the switch could not be
    reached or did not report that outlet.
    """
    # NOTE(review): the returned dict is the raw configuration, including each
    # switch's username and password. Confirm that exposing these to MCP
    # clients is intended.
    config = load_config()
    switches = config["switches"]

    # Run status fetches for all switches in parallel threads
    results = await asyncio.gather(
        *(asyncio.to_thread(_fetch_switch_status_sync, switch_config) for switch_config in switches)
    )

    for switch_config, result in zip(switches, results):
        # None means the fetch failed; treat it like "nothing reported".
        reported = result or {}
        # Walk the configured outlets (not the reported ones) so every entry
        # always ends up with a status key. Outlets known only to the hardware
        # are still ignored; the config stays the source of truth.
        for index_str, outlet_data in switch_config["outlets"].items():
            outlet_data["status"] = reported.get(index_str, "unknown")
    return config
@server.tool()
async def power_action(
    switch_id: str,
    outlet_id: str,
    action: Literal["on", "off", "cycle"],
    confirmation: str = "NO",
) -> str:
    """
    Turn a single outlet on or off, or power-cycle it.

    Args:
        switch_id: Alias or IP address of the switch.
        outlet_id: Index or name of the outlet.
        action: "on", "off" or "cycle".
        confirmation: Must be "YES" to turn off or cycle a critical outlet.

    Returns:
        A "Success: ..." message, a "SAFETY LOCK: ..." message when a critical
        outlet needs confirmation, or an "Error: ..." message if the hardware
        call fails.

    Raises:
        ValueError: If the switch or outlet cannot be resolved.
        PermissionError: If the outlet is of type "prohibited".
    """
    switch = resolve_switch(switch_id)
    index = resolve_outlet(switch, outlet_id)
    outlet_config = switch["outlets"].get(index, {})
    outlet_type = outlet_config.get("type", "standard")

    if outlet_type == "prohibited":
        raise PermissionError("Action denied: Outlet is prohibited from MCP control.")
    if outlet_type == "critical" and action in ("off", "cycle") and confirmation != "YES":
        return "SAFETY LOCK: This is a critical device. Are you sure you want to turn it off? If so, run the command again with confirmation='YES'."

    try:
        # Run blocking power action in a thread; config keys are 1-based,
        # client.outlets is 0-based.
        await asyncio.to_thread(_power_action_sync, switch, int(index)-1, action)
    except Exception as e:
        logger.exception(f"Failed to perform action {action} on outlet {outlet_id} of switch {switch_id}")
        return f"Error: Failed to perform action on outlet. {e}"

    # "turned cycle" is not English; on/off keep their original wording.
    outcome = "power-cycled" if action == "cycle" else f"turned {action}"
    return f"Success: Outlet {outlet_config.get('name', index)} on switch {switch.get('alias', switch_id)} has been {outcome}."
@server.tool()
async def group_power_action(target: str, action: Literal["on", "off", "cycle"]) -> str:
    """
    Apply a power action to every outlet in a configured group.

    Group members are written as "<switch>:<outlet>". The whole action is
    refused if any member is a prohibited outlet. Otherwise each member is
    processed in turn and a failure on one member does not stop the rest.

    Args:
        target: Name of the group.
        action: "on", "off" or "cycle".

    Raises:
        ValueError: If the group does not exist, a member is malformed, or a
            member's switch or outlet cannot be resolved.
        PermissionError: If any member outlet is of type "prohibited".
    """
    config = load_config()
    # A hand-written config may have no "groups" section at all.
    groups = config.get("groups", {})
    if target not in groups:
        raise ValueError(f"Group not found: {target}")
    members = groups[target].get("members", [])

    # Split on the LAST colon so IPv6 switch addresses (which contain colons)
    # still parse; do it once and reuse the result in both passes.
    parsed = []
    for member in members:
        switch_id, separator, outlet_id = member.rpartition(":")
        if not separator:
            raise ValueError(f"Invalid group member '{member}': expected '<switch>:<outlet>'.")
        parsed.append((member, switch_id, outlet_id))

    # Pre-flight check for prohibited outlets
    for member, switch_id, outlet_id in parsed:
        switch = resolve_switch(switch_id)
        index = resolve_outlet(switch, outlet_id)
        if switch["outlets"].get(index, {}).get("type") == "prohibited":
            raise PermissionError(f"Action denied: Group action aborted because outlet {outlet_id} on switch {switch_id} is prohibited.")

    results = []
    errors = []
    for member, switch_id, outlet_id in parsed:
        try:
            # NOTE(review): confirmation="YES" bypasses the critical-outlet
            # safety lock for every group member; confirm this is intended.
            result = await power_action(switch_id, outlet_id, action, confirmation="YES")
            if result.startswith("Error:"):
                errors.append(f"Failed {member}: {result}")
            else:
                results.append(result)
        except Exception as e:
            # Capture error but continue to next member
            logger.exception(f"Failed group action for {member}")
            errors.append(f"Failed {member}: {str(e)}")

    if errors:
        return f"Partial Success. \nCompleted: {results}\nErrors: {errors}"
    return f"Group action '{action}' completed. Results: {results}"
@server.tool()
async def list_outlets(switch_id: str) -> List[Dict[str, Any]]:
    """
    Lists every outlet of the given switch together with its live status.

    ``switch_id`` is the alias or IP address of the switch. The hardware is
    queried in a worker thread.
    """
    target = resolve_switch(switch_id)
    outlets = await asyncio.to_thread(_list_outlets_sync, target)
    return outlets
@server.tool()
async def sync_config_from_hardware(switch_id: str) -> str:
    """
    Refreshes the stored outlet names of one switch from the hardware.

    ``switch_id`` is the alias or IP address of the switch. Returns a success
    message, or an "Error: ..." message if the hardware sync or the save fails.
    Raises ValueError if the switch is not in the configuration.
    """
    async with CONFIG_LOCK:
        # Work on the full document so the whole file can be written back.
        config = load_config()
        switch_config = next(
            (s for s in config["switches"] if s["alias"] == switch_id or s["ip_address"] == switch_id),
            None,
        )
        if not switch_config:
            raise ValueError("Device not found")

        try:
            await asyncio.to_thread(_sync_switch_config_sync, switch_config)
            save_config(config)
        except Exception as e:
            logger.exception(f"Failed to synchronize switch {switch_id}")
            return f"Error: Could not synchronize with hardware. {e}"
        else:
            return f"Successfully synchronized outlet names from hardware for switch '{switch_id}'."
@server.tool()
async def add_switch(ip_address: str, username: str, password: str) -> str:
    """
    Adds a new DLI power switch to the configuration.

    The IP address is validated, the switch is contacted to read its outlet
    names, and the new entry is saved only if that sync succeeds. The alias is
    derived from the IP address and made unique with a numeric suffix.
    """
    try:
        validate_ip(ip_address)
    except ValueError as e:
        return str(e)

    async with CONFIG_LOCK:
        if os.path.exists(CONFIG_FILE):
            config = load_config()
        else:
            # First run: make sure the target directory exists and start empty.
            parent = os.path.dirname(os.path.abspath(CONFIG_FILE))
            if parent and not os.path.exists(parent):
                os.makedirs(parent, exist_ok=True)
            config = {"switches": [], "groups": {}}

        # Default alias is the IP with dots replaced; add _1, _2, ... on clash.
        base_alias = ip_address.replace('.', '_')
        taken = [s["alias"] for s in config["switches"]]
        alias = base_alias
        counter = 1
        while alias in taken:
            alias = f"{base_alias}_{counter}"
            counter += 1

        new_switch_config = {
            "alias": alias,
            "description": f"DLI Power Switch at {ip_address}",
            "ip_address": ip_address,
            "username": username,
            "password": password,
            "outlets": {},  # Filled in by the hardware sync below
        }

        try:
            # Sync from the hardware into the in-memory entry first ...
            await asyncio.to_thread(_sync_switch_config_sync, new_switch_config)
            # ... and only then add and persist it.
            config["switches"].append(new_switch_config)
            save_config(config)
        except Exception as e:
            logger.exception(f"Failed to add switch {alias}")
            return f"Error: Failed to add switch '{alias}'. {e}. The switch has NOT been added to the configuration."
        else:
            return f"Successfully added and synchronized switch '{alias}'."
@server.tool()
async def remove_switch(switch_id: str) -> str:
    """
    Removes a DLI power switch from the configuration.

    Every entry whose alias or IP address equals ``switch_id`` is dropped.
    Raises ValueError if nothing matches.
    """
    async with CONFIG_LOCK:
        config = load_config()
        before = len(config["switches"])
        kept = [
            s for s in config["switches"]
            if not (s["alias"] == switch_id or s["ip_address"] == switch_id)
        ]
        config["switches"] = kept
        if len(kept) == before:
            raise ValueError(f"Switch '{switch_id}' not found in configuration.")
        save_config(config)
    return f"Successfully removed switch '{switch_id}' from the configuration."
@server.tool()
async def update_outlet(
    switch_id: str,
    outlet_id: str,
    new_name: Optional[str] = None,
    new_description: Optional[str] = None,
    new_type: Optional[Literal["standard", "critical", "prohibited"]] = None,
) -> str:
    """
    Updates the definition of an outlet.

    Args:
        switch_id: Alias or IP address of the switch.
        outlet_id: Index or name of the outlet.
        new_name: New outlet name; it is written to the hardware first and
            stored only if that succeeds. Omit to leave the name unchanged.
        new_description: New description. Omit to leave it unchanged.
        new_type: New safety class. Omit to leave it unchanged.

    Returns:
        A "Success: ..." message, or an "Error: ..." message if the hardware
        rename fails (in which case nothing is saved).

    Raises:
        ValueError: If the switch or outlet cannot be resolved.
    """
    # The optional parameters are annotated Optional[...] explicitly: a bare
    # `str = None` is an implicit Optional, which type checkers reject.
    async with CONFIG_LOCK:
        config = load_config()
        switch = None
        for s in config["switches"]:
            if s["alias"] == switch_id or s["ip_address"] == switch_id:
                switch = s
                break
        if not switch:
            raise ValueError("Device not found")

        index = resolve_outlet(switch, outlet_id)
        outlet_config = switch["outlets"].get(index, {})

        if new_name is not None:
            try:
                # Config keys are 1-based, client.outlets is 0-based.
                await asyncio.to_thread(_update_outlet_name_sync, switch, int(index)-1, new_name)
                outlet_config["name"] = new_name
            except Exception as e:
                logger.exception(f"Failed to update outlet name on hardware for switch {switch_id} outlet {outlet_id}")
                return f"Error: Failed to update outlet name on hardware. {e}"
        if new_description is not None:
            outlet_config["description"] = new_description
        if new_type is not None:
            outlet_config["type"] = new_type

        save_config(config)
        return f"Success: Outlet {outlet_id} on switch {switch_id} has been updated."
def create_parser(prog: str, description: str) -> argparse.ArgumentParser:
    """
    Build the base ArgumentParser shared by the command-line interface.

    Help text is rendered with RawTextHelpFormatter, so line breaks in
    ``description`` and in argument help strings are kept as written.
    """
    return argparse.ArgumentParser(
        prog=prog,
        description=description,
        formatter_class=argparse.RawTextHelpFormatter,
    )
def main():
    """
    Command-line entry point.

    With ``--mcp-server`` or no sub-command, runs the MCP server on stdio.
    Otherwise the named sub-command calls the matching tool coroutine once and
    prints its result. FileNotFoundError, ValueError and PermissionError from
    a sub-command are reported on stderr and exit with status 1.
    """
    cli = create_parser(
        "DLI Power Switch Tool",
        description="A command-line tool and MCP server to control DLI Web Power Switches."
    )
    mode_group = cli.add_mutually_exclusive_group(required=False)
    mode_group.add_argument("--mcp-server", action="store_true", help="Run as an MCP server.")
    commands = cli.add_subparsers(dest="command", help="Available commands")

    actions = ["on", "off", "cycle"]

    # Inventory command (takes no arguments)
    commands.add_parser("inventory", help="Get inventory of all switches and their outlet statuses.")

    # Power action command
    sub = commands.add_parser("power_action", help="Perform a power action (on, off, cycle) on a specific outlet.")
    sub.add_argument("switch_id", help="Alias or IP address of the switch.")
    sub.add_argument("outlet_id", help="Index or name of the outlet.")
    sub.add_argument("action", choices=actions, help="Action to perform.")
    sub.add_argument("--confirmation", default="NO", help="Type 'YES' to confirm critical actions.")

    # Group power action command
    sub = commands.add_parser("group_power_action", help="Perform a power action (on, off, cycle) on a group of outlets.")
    sub.add_argument("target", help="Name of the group.")
    sub.add_argument("action", choices=actions, help="Action to perform.")

    # Sync config from hardware command
    sub = commands.add_parser("sync_config_from_hardware", help="Synchronize outlet names from hardware for a specific switch.")
    sub.add_argument("switch_id", help="Alias or IP address of the switch.")

    # List outlets command
    sub = commands.add_parser("list_outlets", help="List all outlets on a given switch.")
    sub.add_argument("switch_id", help="Alias or IP address of the switch.")

    # Add switch command
    sub = commands.add_parser("add_switch", help="Adds a new DLI power switch to the configuration.")
    sub.add_argument("ip_address", help="IP address of the new switch.")
    sub.add_argument("username", help="Username for the new switch.")
    sub.add_argument("password", help="Password for the new switch.")

    # Remove switch command
    sub = commands.add_parser("remove_switch", help="Removes a DLI power switch from the configuration.")
    sub.add_argument("switch_id", help="Alias or IP address of the switch to remove.")

    # Update outlet command
    sub = commands.add_parser("update_outlet", help="Update the definition of an outlet.")
    sub.add_argument("switch_id", help="Alias or IP address of the switch.")
    sub.add_argument("outlet_id", help="Index or name of the outlet.")
    sub.add_argument("--name", help="New name for the outlet.")
    sub.add_argument("--description", help="New description for the outlet.")
    sub.add_argument("--type", choices=["standard", "critical", "prohibited"], help="New type for the outlet.")

    args = cli.parse_args()

    if args.mcp_server or not args.command:
        try:
            # server.run() handles its own loop
            server.run(transport='stdio')
        except KeyboardInterrupt:
            pass
        return

    # Sub-command name -> (coroutine factory, whether the result is printed as JSON).
    # Factories are lambdas so only the selected coroutine is ever created.
    dispatch = {
        "inventory": (lambda: get_inventory(), True),
        "power_action": (lambda: power_action(args.switch_id, args.outlet_id, args.action, args.confirmation), False),
        "group_power_action": (lambda: group_power_action(args.target, args.action), False),
        "sync_config_from_hardware": (lambda: sync_config_from_hardware(args.switch_id), False),
        "list_outlets": (lambda: list_outlets(args.switch_id), True),
        "update_outlet": (
            lambda: update_outlet(
                args.switch_id,
                args.outlet_id,
                new_name=args.name,
                new_description=args.description,
                new_type=args.type,
            ),
            False,
        ),
        "add_switch": (lambda: add_switch(args.ip_address, args.username, args.password), False),
        "remove_switch": (lambda: remove_switch(args.switch_id), False),
    }

    selected = dispatch.get(args.command)
    if selected is None:
        return
    make_coroutine, as_json = selected

    try:
        outcome = asyncio.run(make_coroutine())
        if as_json:
            print(json.dumps(outcome, indent=2))
        else:
            print(outcome)
    except (FileNotFoundError, ValueError, PermissionError) as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
# Run the CLI / MCP server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()