Hass-MCP
by voska
Verified
- hass-mcp
- app
import httpx
from typing import Dict, Any, Optional, List, TypeVar, Callable, Awaitable, Union, cast
import functools
import inspect
import logging
from app.config import HA_URL, HA_TOKEN, get_ha_headers
# Set up logging
logger = logging.getLogger(__name__)
# Define a generic type for our API function return values
T = TypeVar('T')
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
# HTTP client
_client: Optional[httpx.AsyncClient] = None
# Default field sets for different verbosity levels
# Lean fields for standard requests (optimized for token efficiency)
DEFAULT_LEAN_FIELDS = ["entity_id", "state", "attr.friendly_name"]
# Common fields that are typically needed for entity operations
DEFAULT_STANDARD_FIELDS = ["entity_id", "state", "attributes", "last_updated"]
# Domain-specific important attributes to include in lean responses
DOMAIN_IMPORTANT_ATTRIBUTES = {
"light": ["brightness", "color_temp", "rgb_color", "supported_color_modes"],
"switch": ["device_class"],
"binary_sensor": ["device_class"],
"sensor": ["device_class", "unit_of_measurement", "state_class"],
"climate": ["hvac_mode", "current_temperature", "temperature", "hvac_action"],
"media_player": ["media_title", "media_artist", "source", "volume_level"],
"cover": ["current_position", "current_tilt_position"],
"fan": ["percentage", "preset_mode"],
"camera": ["entity_picture"],
"automation": ["last_triggered"],
"scene": [],
"script": ["last_triggered"],
}
def handle_api_errors(func: F) -> F:
"""
Decorator to handle common error cases for Home Assistant API calls
Args:
func: The async function to decorate
Returns:
Wrapped function that handles errors
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Determine return type from function annotation
return_type = inspect.signature(func).return_annotation
is_dict_return = 'Dict' in str(return_type)
is_list_return = 'List' in str(return_type)
# Prepare error formatters based on return type
def format_error(msg: str) -> Any:
if is_dict_return:
return {"error": msg}
elif is_list_return:
return [{"error": msg}]
else:
return msg
try:
# Check if token is available
if not HA_TOKEN:
return format_error("No Home Assistant token provided. Please set HA_TOKEN in .env file.")
# Call the original function
return await func(*args, **kwargs)
except httpx.ConnectError:
return format_error(f"Connection error: Cannot connect to Home Assistant at {HA_URL}")
except httpx.TimeoutException:
return format_error(f"Timeout error: Home Assistant at {HA_URL} did not respond in time")
except httpx.HTTPStatusError as e:
return format_error(f"HTTP error: {e.response.status_code} - {e.response.reason_phrase}")
except httpx.RequestError as e:
return format_error(f"Error connecting to Home Assistant: {str(e)}")
except Exception as e:
return format_error(f"Unexpected error: {str(e)}")
return cast(F, wrapper)
# Persistent HTTP client
async def get_client() -> httpx.AsyncClient:
"""Get a persistent httpx client for Home Assistant API calls"""
global _client
if _client is None:
logger.debug("Creating new HTTP client")
_client = httpx.AsyncClient(timeout=10.0)
return _client
async def cleanup_client() -> None:
"""Close the HTTP client when shutting down"""
global _client
if _client:
logger.debug("Closing HTTP client")
await _client.aclose()
_client = None
# Direct entity retrieval function
async def get_all_entity_states() -> Dict[str, Dict[str, Any]]:
"""Fetch all entity states from Home Assistant"""
client = await get_client()
response = await client.get(f"{HA_URL}/api/states", headers=get_ha_headers())
response.raise_for_status()
entities = response.json()
# Create a mapping for easier access
return {entity["entity_id"]: entity for entity in entities}
def filter_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
"""
Filter entity data to only include requested fields
This function helps reduce token usage by returning only requested fields.
Args:
data: The complete entity data dictionary
fields: List of fields to include in the result
- "state": Include the entity state
- "attributes": Include all attributes
- "attr.X": Include only attribute X (e.g. "attr.brightness")
- "context": Include context data
- "last_updated"/"last_changed": Include timestamp fields
Returns:
A filtered dictionary with only the requested fields
"""
if not fields:
return data
result = {"entity_id": data["entity_id"]}
for field in fields:
if field == "state":
result["state"] = data.get("state")
elif field == "attributes":
result["attributes"] = data.get("attributes", {})
elif field.startswith("attr.") and len(field) > 5:
attr_name = field[5:]
attributes = data.get("attributes", {})
if attr_name in attributes:
if "attributes" not in result:
result["attributes"] = {}
result["attributes"][attr_name] = attributes[attr_name]
elif field == "context":
if "context" in data:
result["context"] = data["context"]
elif field in ["last_updated", "last_changed"]:
if field in data:
result[field] = data[field]
return result
# API Functions
@handle_api_errors
async def get_hass_version() -> str:
"""Get the Home Assistant version from the API"""
client = await get_client()
response = await client.get(f"{HA_URL}/api/config", headers=get_ha_headers())
response.raise_for_status()
data = response.json()
return data.get("version", "unknown")
@handle_api_errors
async def get_entity_state(
entity_id: str,
fields: Optional[List[str]] = None,
lean: bool = False
) -> Dict[str, Any]:
"""
Get the state of a Home Assistant entity
Args:
entity_id: The entity ID to get
fields: Optional list of specific fields to include in the response
lean: If True, returns a token-efficient version with minimal fields
(overridden by fields parameter if provided)
Returns:
Entity state dictionary, optionally filtered to include only specified fields
"""
# Fetch directly
client = await get_client()
response = await client.get(
f"{HA_URL}/api/states/{entity_id}",
headers=get_ha_headers()
)
response.raise_for_status()
entity_data = response.json()
# Apply field filtering if requested
if fields:
# User-specified fields take precedence
return filter_fields(entity_data, fields)
elif lean:
# Build domain-specific lean fields
lean_fields = DEFAULT_LEAN_FIELDS.copy()
# Add domain-specific important attributes
domain = entity_id.split('.')[0]
if domain in DOMAIN_IMPORTANT_ATTRIBUTES:
for attr in DOMAIN_IMPORTANT_ATTRIBUTES[domain]:
lean_fields.append(f"attr.{attr}")
return filter_fields(entity_data, lean_fields)
else:
# Return full entity data
return entity_data
@handle_api_errors
async def get_entities(
domain: Optional[str] = None,
search_query: Optional[str] = None,
limit: int = 100,
fields: Optional[List[str]] = None,
lean: bool = True
) -> List[Dict[str, Any]]:
"""
Get a list of all entities from Home Assistant with optional filtering and search
Args:
domain: Optional domain to filter entities by (e.g., 'light', 'switch')
search_query: Optional case-insensitive search term to filter by entity_id, friendly_name or other attributes
limit: Maximum number of entities to return (default: 100)
fields: Optional list of specific fields to include in each entity
lean: If True (default), returns token-efficient versions with minimal fields
Returns:
List of entity dictionaries, optionally filtered by domain and search terms,
and optionally limited to specific fields
"""
# Get all entities directly
client = await get_client()
response = await client.get(f"{HA_URL}/api/states", headers=get_ha_headers())
response.raise_for_status()
entities = response.json()
# Filter by domain if specified
if domain:
entities = [entity for entity in entities if entity["entity_id"].startswith(f"{domain}.")]
# Search if query is provided
if search_query and search_query.strip():
search_term = search_query.lower().strip()
filtered_entities = []
for entity in entities:
# Search in entity_id
if search_term in entity["entity_id"].lower():
filtered_entities.append(entity)
continue
# Search in friendly_name
friendly_name = entity.get("attributes", {}).get("friendly_name", "").lower()
if friendly_name and search_term in friendly_name:
filtered_entities.append(entity)
continue
# Search in other common attributes (state, area_id, etc.)
if search_term in entity.get("state", "").lower():
filtered_entities.append(entity)
continue
# Search in other attributes
for attr_name, attr_value in entity.get("attributes", {}).items():
# Check if attribute value can be converted to string
if isinstance(attr_value, (str, int, float, bool)):
if search_term in str(attr_value).lower():
filtered_entities.append(entity)
break
entities = filtered_entities
# Apply the limit
if limit > 0 and len(entities) > limit:
entities = entities[:limit]
# Apply field filtering if requested
if fields:
# Use explicit field list when provided
return [filter_fields(entity, fields) for entity in entities]
elif lean:
# Apply domain-specific lean fields to each entity
result = []
for entity in entities:
# Get the entity's domain
entity_domain = entity["entity_id"].split('.')[0]
# Start with basic lean fields
lean_fields = DEFAULT_LEAN_FIELDS.copy()
# Add domain-specific important attributes
if entity_domain in DOMAIN_IMPORTANT_ATTRIBUTES:
for attr in DOMAIN_IMPORTANT_ATTRIBUTES[entity_domain]:
lean_fields.append(f"attr.{attr}")
# Filter and add to result
result.append(filter_fields(entity, lean_fields))
return result
else:
# Return full entities
return entities
@handle_api_errors
async def call_service(domain: str, service: str, data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Call a Home Assistant service"""
if data is None:
data = {}
client = await get_client()
response = await client.post(
f"{HA_URL}/api/services/{domain}/{service}",
headers=get_ha_headers(),
json=data
)
response.raise_for_status()
# Invalidate cache after service calls as they might change entity states
global _entities_timestamp
_entities_timestamp = 0
return response.json()
@handle_api_errors
async def summarize_domain(domain: str, example_limit: int = 3) -> Dict[str, Any]:
"""
Generate a summary of entities in a domain
Args:
domain: The domain to summarize (e.g., 'light', 'switch')
example_limit: Maximum number of examples to include for each state
Returns:
Dictionary with summary information
"""
entities = await get_entities(domain=domain)
# Check if we got an error response
if isinstance(entities, dict) and "error" in entities:
return entities # Just pass through the error
try:
# Initialize summary data
total_count = len(entities)
state_counts = {}
state_examples = {}
attributes_summary = {}
# Process entities to build the summary
for entity in entities:
state = entity.get("state", "unknown")
# Count states
if state not in state_counts:
state_counts[state] = 0
state_examples[state] = []
state_counts[state] += 1
# Add examples (up to the limit)
if len(state_examples[state]) < example_limit:
example = {
"entity_id": entity["entity_id"],
"friendly_name": entity.get("attributes", {}).get("friendly_name", entity["entity_id"])
}
state_examples[state].append(example)
# Collect attribute keys for summary
for attr_key in entity.get("attributes", {}):
if attr_key not in attributes_summary:
attributes_summary[attr_key] = 0
attributes_summary[attr_key] += 1
# Create the summary
summary = {
"domain": domain,
"total_count": total_count,
"state_distribution": state_counts,
"examples": state_examples,
"common_attributes": sorted(
[(k, v) for k, v in attributes_summary.items()],
key=lambda x: x[1],
reverse=True
)[:10] # Top 10 most common attributes
}
return summary
except Exception as e:
return {"error": f"Error generating domain summary: {str(e)}"}
@handle_api_errors
async def get_automations() -> List[Dict[str, Any]]:
"""Get a list of all automations from Home Assistant"""
# Reuse the get_entities function with domain filtering
automation_entities = await get_entities(domain="automation")
# Check if we got an error response
if isinstance(automation_entities, dict) and "error" in automation_entities:
return automation_entities # Just pass through the error
# Process automation entities
result = []
try:
for entity in automation_entities:
# Extract relevant information
automation_info = {
"id": entity["entity_id"].split(".")[1],
"entity_id": entity["entity_id"],
"state": entity["state"],
"alias": entity["attributes"].get("friendly_name", entity["entity_id"]),
}
# Add any additional attributes that might be useful
if "last_triggered" in entity["attributes"]:
automation_info["last_triggered"] = entity["attributes"]["last_triggered"]
result.append(automation_info)
except (TypeError, KeyError) as e:
# Handle errors in processing the entities
return {"error": f"Error processing automation entities: {str(e)}"}
return result
@handle_api_errors
async def reload_automations() -> Dict[str, Any]:
"""Reload all automations in Home Assistant"""
return await call_service("automation", "reload", {})
@handle_api_errors
async def restart_home_assistant() -> Dict[str, Any]:
"""Restart Home Assistant"""
return await call_service("homeassistant", "restart", {})
@handle_api_errors
async def get_hass_error_log() -> Dict[str, Any]:
"""
Get the Home Assistant error log for troubleshooting
Returns:
A dictionary containing:
- log_text: The full error log text
- error_count: Number of ERROR entries found
- warning_count: Number of WARNING entries found
- integration_mentions: Map of integration names to mention counts
- error: Error message if retrieval failed
"""
try:
# Call the Home Assistant API error_log endpoint
url = f"{HA_URL}/api/error_log"
headers = get_ha_headers()
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers, timeout=30)
if response.status_code == 200:
log_text = response.text
# Count errors and warnings
error_count = log_text.count("ERROR")
warning_count = log_text.count("WARNING")
# Extract integration mentions
import re
integration_mentions = {}
# Look for patterns like [mqtt], [zwave], etc.
for match in re.finditer(r'\[([a-zA-Z0-9_]+)\]', log_text):
integration = match.group(1).lower()
if integration not in integration_mentions:
integration_mentions[integration] = 0
integration_mentions[integration] += 1
return {
"log_text": log_text,
"error_count": error_count,
"warning_count": warning_count,
"integration_mentions": integration_mentions
}
else:
return {
"error": f"Error retrieving error log: {response.status_code} {response.reason_phrase}",
"details": response.text,
"log_text": "",
"error_count": 0,
"warning_count": 0,
"integration_mentions": {}
}
except Exception as e:
logger.error(f"Error retrieving Home Assistant error log: {str(e)}")
return {
"error": f"Error retrieving error log: {str(e)}",
"log_text": "",
"error_count": 0,
"warning_count": 0,
"integration_mentions": {}
}
@handle_api_errors
async def get_system_overview() -> Dict[str, Any]:
"""
Get a comprehensive overview of the entire Home Assistant system
Returns:
A dictionary containing:
- total_entities: Total count of all entities
- domains: Dictionary of domains with their entity counts and state distributions
- domain_samples: Representative sample entities for each domain (2-3 per domain)
- domain_attributes: Common attributes for each domain
- area_distribution: Entities grouped by area (if available)
"""
try:
# Get ALL entities with minimal fields for efficiency
# We retrieve all entities since API calls don't consume tokens, only responses do
client = await get_client()
response = await client.get(f"{HA_URL}/api/states", headers=get_ha_headers())
response.raise_for_status()
all_entities_raw = response.json()
# Apply lean formatting to reduce token usage in the response
all_entities = []
for entity in all_entities_raw:
domain = entity["entity_id"].split(".")[0]
# Start with basic lean fields
lean_fields = ["entity_id", "state", "attr.friendly_name"]
# Add domain-specific important attributes
if domain in DOMAIN_IMPORTANT_ATTRIBUTES:
for attr in DOMAIN_IMPORTANT_ATTRIBUTES[domain]:
lean_fields.append(f"attr.{attr}")
# Filter and add to result
all_entities.append(filter_fields(entity, lean_fields))
# Initialize overview structure
overview = {
"total_entities": len(all_entities),
"domains": {},
"domain_samples": {},
"domain_attributes": {},
"area_distribution": {}
}
# Group entities by domain
domain_entities = {}
for entity in all_entities:
domain = entity["entity_id"].split(".")[0]
if domain not in domain_entities:
domain_entities[domain] = []
domain_entities[domain].append(entity)
# Process each domain
for domain, entities in domain_entities.items():
# Count entities in this domain
count = len(entities)
# Collect state distribution
state_distribution = {}
for entity in entities:
state = entity.get("state", "unknown")
if state not in state_distribution:
state_distribution[state] = 0
state_distribution[state] += 1
# Store domain information
overview["domains"][domain] = {
"count": count,
"states": state_distribution
}
# Select representative samples (2-3 per domain)
sample_limit = min(3, count)
samples = []
for i in range(sample_limit):
entity = entities[i]
samples.append({
"entity_id": entity["entity_id"],
"state": entity.get("state", "unknown"),
"friendly_name": entity.get("attributes", {}).get("friendly_name", entity["entity_id"])
})
overview["domain_samples"][domain] = samples
# Collect common attributes for this domain
attribute_counts = {}
for entity in entities:
for attr in entity.get("attributes", {}):
if attr not in attribute_counts:
attribute_counts[attr] = 0
attribute_counts[attr] += 1
# Get top 5 most common attributes for this domain
common_attributes = sorted(attribute_counts.items(), key=lambda x: x[1], reverse=True)[:5]
overview["domain_attributes"][domain] = [attr for attr, count in common_attributes]
# Group by area if available
for entity in entities:
area_id = entity.get("attributes", {}).get("area_id", "Unknown")
area_name = entity.get("attributes", {}).get("area_name", area_id)
if area_name not in overview["area_distribution"]:
overview["area_distribution"][area_name] = {}
if domain not in overview["area_distribution"][area_name]:
overview["area_distribution"][area_name][domain] = 0
overview["area_distribution"][area_name][domain] += 1
# Add summary information
overview["domain_count"] = len(domain_entities)
overview["most_common_domains"] = sorted(
[(domain, len(entities)) for domain, entities in domain_entities.items()],
key=lambda x: x[1],
reverse=True
)[:5]
return overview
except Exception as e:
logger.error(f"Error generating system overview: {str(e)}")
return {"error": f"Error generating system overview: {str(e)}"}