#!/usr/bin/env python3
"""
WebSocket server for distributed Synapse nodes.
Provides real-time pub/sub messaging for Amicus clusters across multiple hosts.
Implements Phase 4 from docs/REALTIME_COMMUNICATION_RESEARCH.md
"""
import asyncio
import json
import time
from typing import Dict, Set, Optional
from dataclasses import dataclass, asdict
import websockets
from websockets.server import WebSocketServerProtocol
@dataclass
class ClientInfo:
    """Per-connection record kept by the server for one registered node.

    Instances are created when a client completes registration and are
    looked up by connection object for the lifetime of that connection.
    """

    # Identifier the client supplied in its registration message.
    node_id: str
    # The live WebSocket connection this record describes.
    connection: WebSocketServerProtocol
    # time.time() value captured when the client registered.
    connected_at: float
    # time.time() value of the most recent heartbeat (initially the
    # registration time).
    last_heartbeat: float
    # Topics this client is currently subscribed to.
    subscriptions: Set[str]
class SynapseWebSocketServer:
    """
    WebSocket server for Amicus Synapse.

    Provides real-time pub/sub messaging for distributed nodes.

    Every frame exchanged with a client is a JSON object carrying a
    ``type`` key.  A client must first send
    ``{"type": "register", "node_id": "<non-empty string>"}``; afterwards it
    may send ``subscribe``, ``unsubscribe``, ``publish``, ``query`` and
    ``heartbeat`` messages.  Published events are forwarded to every other
    connection subscribed to the same topic.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8765):
        """Create a server that will listen on ``host``:``port``.

        Nothing is bound until :meth:`start` is awaited.
        """
        self.host = host
        self.port = port

        # Client management
        self.clients: Dict[WebSocketServerProtocol, ClientInfo] = {}
        self.node_lookup: Dict[str, WebSocketServerProtocol] = {}

        # Subscriptions (topic -> set of connections)
        self.subscriptions: Dict[str, Set[WebSocketServerProtocol]] = {}

        # Statistics
        self.message_count = 0
        self.start_time = time.time()

    async def start(self):
        """Start the WebSocket server and serve until cancelled."""
        async with websockets.serve(
            self.handle_client,
            self.host,
            self.port,
            ping_interval=30,
            ping_timeout=10
        ):
            print(f"✓ WebSocket server listening on ws://{self.host}:{self.port}")
            await asyncio.Future()  # Never resolved: run forever

    async def handle_client(self, websocket: WebSocketServerProtocol):
        """Handle one client connection from registration to disconnect.

        The first frame must arrive within 10 seconds and must be a JSON
        object of type ``register`` carrying a non-empty string ``node_id``;
        otherwise the connection is closed with code 1008.  After the
        registration is acknowledged, every further frame is passed to
        :meth:`handle_message`.  Server-side state for the connection is
        always released on exit.
        """
        client_addr = websocket.remote_address
        print(f" Client connected from {client_addr}")
        try:
            # Wait for registration message
            registration = await asyncio.wait_for(
                websocket.recv(),
                timeout=10.0
            )
            try:
                msg = json.loads(registration)
            except ValueError:
                # Covers JSONDecodeError and invalid UTF-8 in binary frames.
                await websocket.close(1008, "Registration must be valid JSON")
                return

            # Valid JSON that is not an object (list, number, ...) has no
            # .get(); treat it like any other non-registration frame.
            if not isinstance(msg, dict) or msg.get('type') != 'register':
                await websocket.close(1008, "First message must be registration")
                return

            # node_id is used as a dict key, so it must be hashable; require
            # the documented type (a non-empty string).
            node_id = msg.get('node_id')
            if not isinstance(node_id, str) or not node_id:
                await websocket.close(1008, "node_id required")
                return

            # Register client
            now = time.time()
            self.clients[websocket] = ClientInfo(
                node_id=node_id,
                connection=websocket,
                connected_at=now,
                last_heartbeat=now,
                subscriptions=set()
            )
            # A reconnecting node replaces the lookup entry of its previous
            # connection; disconnect_client() takes care not to undo this.
            self.node_lookup[node_id] = websocket
            print(f" Registered: {node_id}")

            # Send acknowledgment
            await websocket.send(json.dumps({
                'type': 'registered',
                'node_id': node_id,
                'server_time': time.time()
            }))

            # Message loop
            async for message in websocket:
                await self.handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            pass
        except asyncio.TimeoutError:
            print(f" Registration timeout for {client_addr}")
        except Exception as e:
            print(f" Error handling client {client_addr}: {e}")
        finally:
            await self.disconnect_client(websocket)

    async def handle_message(self, websocket: WebSocketServerProtocol, message: str):
        """Decode one incoming frame and dispatch it by its ``type``.

        Malformed frames and handler errors are reported and dropped so that
        a single bad message does not end the connection.  ``message_count``
        is incremented only for frames that were dispatched without raising
        (including frames of unknown type).
        """
        try:
            msg = json.loads(message)
        except ValueError:
            # Covers JSONDecodeError and invalid UTF-8 in binary frames.
            client_info = self.clients.get(websocket)
            node_id = client_info.node_id if client_info else "unknown"
            print(f" Invalid JSON from {node_id}")
            return

        if not isinstance(msg, dict):
            client_info = self.clients.get(websocket)
            node_id = client_info.node_id if client_info else "unknown"
            print(f" Invalid message from {node_id}: expected a JSON object")
            return

        try:
            msg_type = msg.get('type')
            if msg_type == 'subscribe':
                await self.handle_subscribe(websocket, msg)
            elif msg_type == 'unsubscribe':
                await self.handle_unsubscribe(websocket, msg)
            elif msg_type == 'publish':
                await self.handle_publish(websocket, msg)
            elif msg_type == 'query':
                await self.handle_query(websocket, msg)
            elif msg_type == 'heartbeat':
                await self.handle_heartbeat(websocket)
            else:
                print(f" Unknown message type: {msg_type}")
            self.message_count += 1
        except Exception as e:
            print(f" Error handling message: {e}")

    def _drop_subscription(self, websocket: WebSocketServerProtocol, topic: str) -> None:
        """Remove ``websocket`` from ``topic`` and prune the topic if empty."""
        subscribers = self.subscriptions.get(topic)
        if subscribers is None:
            return
        subscribers.discard(websocket)
        if not subscribers:
            del self.subscriptions[topic]

    async def handle_subscribe(self, websocket: WebSocketServerProtocol, msg: dict):
        """Subscribe the sender to ``msg['topic']`` and acknowledge it.

        Requests without a topic are ignored silently.
        """
        topic = msg.get('topic')
        if not topic:
            return

        # Add to subscription map
        self.subscriptions.setdefault(topic, set()).add(websocket)

        # Update client info
        self.clients[websocket].subscriptions.add(topic)

        # Send acknowledgment
        await websocket.send(json.dumps({
            'type': 'subscribed',
            'topic': topic
        }))

    async def handle_unsubscribe(self, websocket: WebSocketServerProtocol, msg: dict):
        """Unsubscribe the sender from ``msg['topic']`` and acknowledge it.

        Requests without a topic are ignored silently.  Unsubscribing from a
        topic the sender was not subscribed to is still acknowledged.
        """
        topic = msg.get('topic')
        if not topic:
            return

        # Remove from subscription map
        self._drop_subscription(websocket, topic)

        # Update client info
        self.clients[websocket].subscriptions.discard(topic)

        # Send acknowledgment
        await websocket.send(json.dumps({
            'type': 'unsubscribed',
            'topic': topic
        }))

    async def handle_publish(self, websocket: WebSocketServerProtocol, msg: dict):
        """Forward ``msg['data']`` to every other subscriber of ``msg['topic']``.

        The publisher never receives its own event.  Requests without a
        topic, and topics without subscribers, are ignored silently.
        """
        topic = msg.get('topic')
        data = msg.get('data', {})
        if not topic:
            return

        # Get publisher info
        client_info = self.clients[websocket]

        subscribers = self.subscriptions.get(topic)
        if not subscribers:
            return

        # Serialize once; the same text is sent to every subscriber.
        message = json.dumps({
            'type': 'event',
            'topic': topic,
            'data': data,
            'publisher': client_info.node_id,
            'timestamp': time.time()
        })

        # Build every send before awaiting so the subscriber set is not
        # iterated while other coroutines may be changing it.
        tasks = [
            subscriber.send(message)
            for subscriber in subscribers
            if subscriber != websocket
        ]
        if tasks:
            # return_exceptions=True: a failed send to one subscriber must
            # not prevent delivery to the others.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def handle_query(self, websocket: WebSocketServerProtocol, msg: dict):
        """Answer a state query with a ``query_response`` message.

        ``msg['query']`` selects the payload: ``'nodes'`` lists connected
        nodes, ``'stats'`` returns server statistics.  For any other value
        the response carries no ``data`` key.  ``msg['request_id']`` is
        echoed back unchanged.
        """
        query_type = msg.get('query')
        request_id = msg.get('request_id')
        response = {
            'type': 'query_response',
            'request_id': request_id
        }

        if query_type == 'nodes':
            # Return list of connected nodes
            response['data'] = [
                {'node_id': info.node_id, 'connected_at': info.connected_at}
                for info in self.clients.values()
            ]
        elif query_type == 'stats':
            # Return server statistics
            uptime = time.time() - self.start_time
            response['data'] = {
                'uptime': uptime,
                'client_count': len(self.clients),
                'message_count': self.message_count,
                'subscriptions': {
                    topic: len(subs)
                    for topic, subs in self.subscriptions.items()
                }
            }

        await websocket.send(json.dumps(response))

    async def handle_heartbeat(self, websocket: WebSocketServerProtocol):
        """Record the heartbeat time for the sender and acknowledge it."""
        if websocket in self.clients:
            self.clients[websocket].last_heartbeat = time.time()
        await websocket.send(json.dumps({
            'type': 'heartbeat_ack',
            'server_time': time.time()
        }))

    async def disconnect_client(self, websocket: WebSocketServerProtocol):
        """Release all server-side state held for ``websocket``.

        Safe to call for a connection that never registered or has already
        been cleaned up; in that case nothing happens.
        """
        client_info = self.clients.pop(websocket, None)
        if client_info is None:
            return
        node_id = client_info.node_id

        # Remove from all subscriptions
        for topic in client_info.subscriptions:
            self._drop_subscription(websocket, topic)

        # Only drop the lookup entry if it still refers to this connection.
        # If the node re-registered on a new connection, the entry belongs
        # to that newer connection and must survive this cleanup.
        if self.node_lookup.get(node_id) is websocket:
            del self.node_lookup[node_id]

        print(f" Disconnected: {node_id}")
def main():
    """Run the server as a standalone command-line program.

    Reads ``--host`` and ``--port`` from the command line, builds a
    SynapseWebSocketServer with them and serves until interrupted with
    Ctrl-C.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Amicus WebSocket Server")
    cli.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    cli.add_argument("--port", type=int, default=8765, help="Port to listen on")
    options = cli.parse_args()

    ws_server = SynapseWebSocketServer(host=options.host, port=options.port)
    try:
        asyncio.run(ws_server.start())
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server; exit quietly.
        print("\nShutting down...")


if __name__ == "__main__":
    main()