from config import config
import contextlib
from dataclasses import dataclass
import errno
import json
import logging
from pathlib import Path
from port_discovery import PortDiscovery
import random
import socket
import struct
import threading
import time
from typing import Any, Dict
# Configure root logging from the level name and format string held in config.
logging.basicConfig(
    level=getattr(logging, config.log_level),
    format=config.log_format,
)
logger = logging.getLogger("mcp-for-unity-server")

# Held while the module-level Unity connection is being created.
_connection_lock = threading.Lock()

# Largest framed payload accepted from Unity: 64 MiB.
FRAMED_MAX = 64 << 20
@dataclass
class UnityConnection:
    """Manages the socket connection to the Unity Editor.

    An instance owns at most one TCP socket to the Unity MCP bridge.
    ``connect()`` opens it and runs a handshake in which Unity is expected to
    advertise ``FRAMING=1``. With framing, every message is preceded by an
    8-byte big-endian length header. Without framing (allowed only when
    ``config.require_framing`` is false), raw bytes are exchanged and the end
    of a reply is detected by trying to parse the accumulated bytes as JSON.
    """

    host: str = config.unity_host
    port: int | None = None  # Discovered in __post_init__ when not supplied
    sock: socket.socket | None = None  # Live socket, or None while disconnected
    use_framing: bool = False  # Negotiated per connection by the handshake

    def __post_init__(self):
        """Discover the port if it was not explicitly provided, and create the locks."""
        if self.port is None:
            self.port = PortDiscovery.discover_unity_port()
        # Serializes each send/receive pair on the shared socket.
        self._io_lock = threading.Lock()
        # Ensures only one thread at a time opens a socket in connect().
        self._conn_lock = threading.Lock()

    def connect(self) -> bool:
        """Establish a connection to the Unity Editor.

        Opens a TCP socket and disables Nagle's algorithm. The connect timeout
        comes from ``config.connect_timeout``, falling back to
        ``config.connection_timeout`` and then 1 s. It then performs the
        framing handshake. The new socket is stored on ``self.sock`` only
        after the handshake succeeds, so the lock-free check at the top never
        reports a connection whose wire mode is still unknown.

        Returns:
            True if connected (or already connected), False on any failure.
            Failures are logged, not raised.
        """
        if self.sock:
            return True
        with self._conn_lock:
            if self.sock:
                return True
            sock = None
            try:
                # Bounded connect to avoid indefinite blocking
                connect_timeout = float(
                    getattr(config, "connect_timeout", getattr(config, "connection_timeout", 1.0)))
                sock = socket.create_connection(
                    (self.host, self.port), connect_timeout)
                # Disable Nagle's algorithm to reduce small RPC latency
                with contextlib.suppress(Exception):
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                logger.debug(f"Connected to Unity at {self.host}:{self.port}")
                try:
                    use_framing = self._negotiate_framing(sock)
                finally:
                    sock.settimeout(config.connection_timeout)
                # Publish the wire mode before the socket so any reader that
                # sees the socket also sees the negotiated mode.
                self.use_framing = use_framing
                self.sock = sock
                return True
            except Exception as e:
                logger.error(f"Failed to connect to Unity: {str(e)}")
                if sock is not None:
                    with contextlib.suppress(Exception):
                        sock.close()
                return False

    def _negotiate_framing(self, sock: socket.socket) -> bool:
        """Read Unity's handshake text from ``sock`` and choose the wire mode.

        Reads until a newline, 512 bytes, EOF, or ``config.handshake_timeout``
        (default 1 s), whichever comes first. Leaves the socket timeout set to
        the handshake timeout; the caller restores it.

        Returns:
            True if the peer advertised ``FRAMING=1``. Returns False (legacy
            mode, with a warning) if it did not and ``config.require_framing``
            is false.

        Raises:
            ConnectionError: if ``FRAMING=1`` is missing and
                ``config.require_framing`` (default True) is set.
        """
        require_framing = getattr(config, "require_framing", True)
        timeout = float(getattr(config, "handshake_timeout", 1.0))
        sock.settimeout(timeout)
        buf = bytearray()
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline and len(buf) < 512:
            try:
                chunk = sock.recv(256)
            except socket.timeout:
                break
            if not chunk:
                break
            buf.extend(chunk)
            if b"\n" in buf:
                break
        text = bytes(buf).decode('ascii', errors='ignore').strip()
        if 'FRAMING=1' in text:
            logger.debug('Unity MCP handshake received: FRAMING=1 (strict)')
            return True
        if require_framing:
            # Best-effort plain-text advisory for legacy peers
            with contextlib.suppress(Exception):
                sock.sendall(b'Unity MCP requires FRAMING=1\n')
            raise ConnectionError(
                f'Unity MCP requires FRAMING=1, got: {text!r}')
        logger.warning(
            'Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration')
        return False

    def disconnect(self):
        """Close the connection to the Unity Editor (no-op when not connected).

        Errors from ``close()`` are logged, and ``self.sock`` is always reset
        to None.
        """
        if self.sock:
            try:
                self.sock.close()
            except Exception as e:
                logger.error(f"Error disconnecting from Unity: {str(e)}")
            finally:
                self.sock = None

    def _read_exact(self, sock: socket.socket, count: int) -> bytes:
        """Read exactly ``count`` bytes from ``sock``.

        Raises:
            ConnectionError: if the peer closes before ``count`` bytes arrive.
        """
        data = bytearray()
        while len(data) < count:
            chunk = sock.recv(count - len(data))
            if not chunk:
                raise ConnectionError(
                    "Connection closed before reading expected bytes")
            data.extend(chunk)
        return bytes(data)

    def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes:
        """Receive one complete response from Unity on ``sock``.

        Dispatches on the negotiated wire mode. ``buffer_size`` is the
        per-``recv`` size used only in legacy (unframed) mode.

        Raises:
            TimeoutError: if the socket times out.
            ConnectionError: if the peer closes before any data arrives, or
                in framed mode before a complete frame arrives.
            ValueError: if a frame header announces more than FRAMED_MAX bytes.
        """
        if self.use_framing:
            return self._receive_framed(sock)
        return self._receive_legacy(sock, buffer_size)

    def _receive_framed(self, sock) -> bytes:
        """Read one length-prefixed frame from ``sock``, skipping heartbeats.

        Each frame is an 8-byte big-endian unsigned length followed by that
        many payload bytes. Zero-length frames are heartbeats. Once
        ``config.max_heartbeat_frames`` (default 16) of them have been seen,
        or one arrives after ``config.framed_receive_timeout`` seconds
        (default 2) have elapsed, b"" is returned instead of waiting further.
        """
        try:
            heartbeat_count = 0
            deadline = time.monotonic() + getattr(config, 'framed_receive_timeout', 2.0)
            max_heartbeats = getattr(config, 'max_heartbeat_frames', 16)
            while True:
                header = self._read_exact(sock, 8)
                (payload_len,) = struct.unpack('>Q', header)
                if payload_len == 0:
                    # Heartbeat/no-op frame: consume and keep waiting for a data frame
                    logger.debug("Received heartbeat frame (length=0)")
                    heartbeat_count += 1
                    if heartbeat_count >= max_heartbeats or time.monotonic() > deadline:
                        # Treat as empty successful response to match C# server behavior
                        logger.debug(
                            "Heartbeat threshold reached; returning empty response")
                        return b""
                    continue
                if payload_len > FRAMED_MAX:
                    raise ValueError(f"Invalid framed length: {payload_len}")
                payload = self._read_exact(sock, payload_len)
                logger.debug(f"Received framed response ({len(payload)} bytes)")
                return payload
        except socket.timeout as e:
            logger.warning("Socket timeout during framed receive")
            raise TimeoutError("Timeout receiving Unity response") from e
        except Exception as e:
            logger.error(f"Error during framed receive: {str(e)}")
            raise

    def _receive_legacy(self, sock, buffer_size) -> bytes:
        """Read unframed bytes until they form a complete JSON reply.

        Honors the socket's currently configured timeout. If the peer closes
        after sending data that never parsed, that data is returned unchanged
        so the caller's JSON decoding reports the problem.
        """
        chunks = []
        try:
            while True:
                chunk = sock.recv(buffer_size)
                if not chunk:
                    if not chunks:
                        raise ConnectionError(
                            "Connection closed before receiving data")
                    logger.warning(
                        "Connection closed before a complete JSON response arrived")
                    return b''.join(chunks)
                chunks.append(chunk)
                data = b''.join(chunks)
                if self._is_complete_legacy_response(data):
                    return data
        except socket.timeout as e:
            logger.warning("Socket timeout during receive")
            raise TimeoutError("Timeout receiving Unity response") from e
        except Exception as e:
            logger.error(f"Error during receive: {str(e)}")
            raise

    @staticmethod
    def _is_complete_legacy_response(data: bytes) -> bool:
        """Return True once ``data`` is a complete legacy reply.

        A reply is complete when it decodes as UTF-8 and is either valid JSON
        or starts like a pong reply. The raw bytes are validated unchanged,
        because they are exactly what the caller will parse.
        """
        try:
            decoded_data = data.decode('utf-8')
        except UnicodeDecodeError:
            # A multi-byte character may be split across recv() chunks; wait for more.
            return False
        # Special case for ping-pong
        if decoded_data.strip().startswith('{"status":"success","result":{"message":"pong"'):
            logger.debug("Received ping response")
            return True
        try:
            json.loads(decoded_data)
        except json.JSONDecodeError:
            # We haven't received a complete valid JSON response yet
            return False
        except Exception as e:
            logger.warning(f"Error processing response chunk: {str(e)}")
            return False
        logger.info(f"Received complete response ({len(data)} bytes)")
        return True

    @staticmethod
    def _read_status_file() -> dict | None:
        """Load the most recently modified ``~/.unity-mcp/unity-mcp-status-*.json``.

        Best effort: returns None if no file exists or anything goes wrong
        (missing directory, a file vanishing mid-scan, invalid JSON).
        """
        try:
            status_dir = Path.home().joinpath('.unity-mcp')
            latest = max(status_dir.glob('unity-mcp-status-*.json'),
                         key=lambda p: p.stat().st_mtime, default=None)
            if latest is None:
                return None
            with latest.open('r') as f:
                return json.load(f)
        except Exception:
            return None

    @staticmethod
    def _status_indicates_reload(status) -> bool:
        """Return True if a loaded status object says Unity is reloading."""
        return isinstance(status, dict) and bool(
            status.get('reloading') or status.get('reason') == 'reloading')

    @staticmethod
    def _build_payload(command_type: str, params: Dict[str, Any]) -> bytes:
        """Encode a command: raw b'ping' for pings, UTF-8 JSON otherwise."""
        if command_type == 'ping':
            return b'ping'
        command = {"type": command_type, "params": params or {}}
        return json.dumps(command, ensure_ascii=False).encode('utf-8')

    def _exchange(self, payload: bytes, *, short_timeout: bool) -> bytes:
        """Write ``payload`` to the current socket and return the raw reply bytes.

        Must be called with ``self._io_lock`` held. With ``short_timeout``
        (used on retry attempts), the receive runs under a 1 s socket timeout
        and the previous timeout is restored afterwards, even on failure.
        """
        sock = self.sock
        mode = 'framed' if self.use_framing else 'legacy'
        with contextlib.suppress(Exception):
            logger.debug(
                "send %d bytes; mode=%s; head=%s",
                len(payload),
                mode,
                (payload[:32]).decode('utf-8', 'ignore'),
            )
        if self.use_framing:
            sock.sendall(struct.pack('>Q', len(payload)))
        sock.sendall(payload)
        previous_timeout = None
        if short_timeout:
            previous_timeout = sock.gettimeout()
            sock.settimeout(1.0)
        try:
            response_data = self.receive_full_response(sock)
            with contextlib.suppress(Exception):
                logger.debug("recv %d bytes; mode=%s", len(response_data), mode)
        finally:
            # Key off the flag, not the saved value: a blocking socket reports None.
            if short_timeout:
                sock.settimeout(previous_timeout)
        return response_data

    @staticmethod
    def _parse_response(command_type: str, response_data: bytes) -> Dict[str, Any]:
        """Decode Unity's JSON reply and extract the result.

        Raises:
            Exception: if a ping did not return pong, or Unity reported an error.
        """
        resp = json.loads(response_data.decode('utf-8'))
        if command_type == 'ping':
            if resp.get('status') == 'success' and resp.get('result', {}).get('message') == 'pong':
                return {"message": "pong"}
            raise Exception("Ping unsuccessful")
        if resp.get('status') == 'error':
            err = resp.get('error') or resp.get('message', 'Unknown Unity error')
            raise Exception(err)
        return resp.get('result', {})

    def _rediscover_port(self) -> None:
        """Re-run port discovery and adopt a changed port; failures are logged at debug."""
        try:
            new_port = PortDiscovery.discover_unity_port()
        except Exception as de:
            logger.debug(f"Port discovery failed: {de}")
            return
        if new_port != self.port:
            logger.info(f"Unity port changed {self.port} -> {new_port}")
            self.port = new_port

    def _retry_delay(self, error: Exception, attempt: int) -> float:
        """Return the jittered, state-capped sleep before retry ``attempt + 1``.

        The delay is ``uniform(0.1, 0.3) * 2**attempt``. It is capped at 0.8 s
        while the status file reports a reload, 0.25 s for transient socket
        errors (refused/reset/timeout), and 3 s otherwise.
        """
        status = self._read_status_file()
        jitter = random.uniform(0.1, 0.3)
        # Fast retry for transient socket failures
        fast_error = isinstance(
            error, (ConnectionRefusedError, ConnectionResetError, TimeoutError))
        if not fast_error:
            with contextlib.suppress(Exception):
                fast_error = getattr(error, 'errno', None) in (
                    errno.ECONNREFUSED, errno.ECONNRESET, errno.ETIMEDOUT)
        if self._status_indicates_reload(status):
            cap = 0.8
        elif fast_error:
            cap = 0.25
        else:
            cap = 3.0
        return min(cap, jitter * (2 ** attempt))

    def send_command(self, command_type: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send a command with retry/backoff and port rediscovery. Pings only when requested.

        Args:
            command_type: Unity command name. ``'ping'`` is sent as the raw
                payload ``b'ping'``; anything else is sent as JSON
                ``{"type": command_type, "params": params}``.
            params: Command parameters. ``None`` is rejected with a structured
                error dict instead of being sent.

        Returns:
            ``{"message": "pong"}`` for a successful ping; otherwise the
            ``result`` member of Unity's reply (``{}`` if absent). A structured
            ``{"success": False, ...}`` dict is returned when ``params`` is
            None or the newest status file reports a reload.

        Raises:
            ValueError: if ``command_type`` is empty.
            Exception: the last failure once ``max(config.max_retries, 5) + 1``
                attempts are exhausted (connection errors, timeouts, errors
                reported by Unity, malformed replies).
        """
        # Defensive guard: catch empty/placeholder invocations early
        if not command_type:
            raise ValueError("MCP call missing command_type")
        if params is None:
            # Return a fast, structured error that clients can display without hanging
            return {"success": False, "error": "MCP call received with no parameters (client placeholder?)"}
        attempts = max(config.max_retries, 5)

        # Preflight: if Unity reports reloading, return a structured hint so clients can retry politely
        try:
            if self._status_indicates_reload(self._read_status_file()):
                return {
                    "success": False,
                    "state": "reloading",
                    "retry_after_ms": int(config.reload_retry_ms),
                    "error": "Unity domain reload in progress",
                    "message": "Unity is reloading scripts; please retry shortly"
                }
        except Exception as exc:
            # Best effort only: an odd config value must not block the command.
            logger.debug(f"Reload preflight skipped: {exc}")

        for attempt in range(attempts + 1):
            try:
                # Ensure connected (handshake occurs within connect())
                if not self.sock and not self.connect():
                    raise ConnectionError("Could not connect to Unity")
                payload = self._build_payload(command_type, params)
                # Send/receive are serialized to protect the shared socket
                with self._io_lock:
                    response_data = self._exchange(payload, short_timeout=attempt > 0)
                return self._parse_response(command_type, response_data)
            except Exception as e:
                logger.warning(
                    f"Unity communication attempt {attempt+1} failed: {e}")
                # Drop the possibly broken socket; close errors are logged, not raised.
                self.disconnect()
                # Re-discover port each time
                self._rediscover_port()
                if attempt < attempts:
                    time.sleep(self._retry_delay(e, attempt))
                    continue
                raise
# Global Unity connection, created lazily by get_unity_connection()
_unity_connection = None


def get_unity_connection() -> UnityConnection:
    """Retrieve or establish a persistent Unity connection.

    Note: Do NOT ping on every retrieval to avoid connection storms. Rely on
    send_command() exceptions to detect broken sockets and reconnect there.

    Returns:
        The shared UnityConnection instance.

    Raises:
        ConnectionError: if a new connection could not be established.
    """
    global _unity_connection
    conn = _unity_connection
    if conn is not None:
        return conn
    # Double-checked locking: only one thread creates the connection. The
    # global is assigned only after connect() succeeds, so the unlocked check
    # above never hands out an unconnected or soon-discarded instance.
    with _connection_lock:
        if _unity_connection is not None:
            return _unity_connection
        logger.info("Creating new Unity connection")
        conn = UnityConnection()
        if not conn.connect():
            raise ConnectionError(
                "Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.")
        _unity_connection = conn
        logger.info("Connected to Unity on startup")
        return conn
# -----------------------------
# Centralized retry helpers
# -----------------------------
def _is_reloading_response(resp: dict) -> bool:
"""Return True if the Unity response indicates the editor is reloading."""
if not isinstance(resp, dict):
return False
if resp.get("state") == "reloading":
return True
message_text = (resp.get("message") or resp.get("error") or "").lower()
return "reload" in message_text
def send_command_with_retry(command_type: str, params: Dict[str, Any], *, max_retries: int | None = None, retry_ms: int | None = None) -> Dict[str, Any]:
    """Send a command via the shared connection, waiting politely through Unity reloads.

    Uses config.reload_retry_ms and config.reload_max_retries by default. Preserves the
    structured failure if retries are exhausted.

    Args:
        command_type: Unity command name forwarded to ``send_command``.
        params: Command parameters forwarded to ``send_command``.
        max_retries: Maximum number of re-sends while Unity reports reloading
            (default ``config.reload_max_retries``, else 40).
        retry_ms: Delay between re-sends when the response carries no usable
            ``retry_after_ms`` hint (default ``config.reload_retry_ms``, else 250).

    Returns:
        The first non-reloading response, or the last reloading response once
        retries are exhausted.
    """
    conn = get_unity_connection()
    if max_retries is None:
        max_retries = getattr(config, "reload_max_retries", 40)
    if retry_ms is None:
        retry_ms = getattr(config, "reload_retry_ms", 250)
    response = conn.send_command(command_type, params)
    retries = 0
    while _is_reloading_response(response) and retries < max_retries:
        # _is_reloading_response only returns True for dicts, so .get is safe.
        try:
            delay_ms = int(response.get("retry_after_ms", retry_ms))
        except (TypeError, ValueError):
            # Malformed hint (e.g. null or non-numeric): use the default delay.
            delay_ms = retry_ms
        time.sleep(max(0.0, delay_ms / 1000.0))
        retries += 1
        response = conn.send_command(command_type, params)
    return response
async def async_send_command_with_retry(command_type: str, params: Dict[str, Any], *, loop=None, max_retries: int | None = None, retry_ms: int | None = None) -> Dict[str, Any]:
"""Async wrapper that runs the blocking retry helper in a thread pool."""
try:
import asyncio # local import to avoid mandatory asyncio dependency for sync callers
if loop is None:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: send_command_with_retry(
command_type, params, max_retries=max_retries, retry_ms=retry_ms),
)
except Exception as e:
# Return a structured error dict for consistency with other responses
return {"success": False, "error": f"Python async retry helper failed: {str(e)}"}