websocket.py•5.92 kB
"""
WebSocket transport for real-time bidirectional ACP communication.
"""
import asyncio
import json
import uuid
from typing import AsyncGenerator, Dict, Set, Coroutine, Any
import logging
try:
import websockets
from websockets import ServerConnection
from websockets.exceptions import ConnectionClosed
WEBSOCKETS_AVAILABLE = True
except ImportError:
WEBSOCKETS_AVAILABLE = False
ServerConnection = None
ConnectionClosed = None
from .base import BaseTransport, TransportMessage
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
class WebSocketTransport(BaseTransport):
    """WebSocket transport for real-time bidirectional communication.

    Runs a ``websockets`` server. Every accepted client gets a generated
    connection id, an entry in ``connections`` and its own inbound
    ``asyncio.Queue`` in ``message_queues``. Outbound messages are broadcast
    to all clients (``send_message``) or addressed to one
    (``send_to_client``); inbound messages are surfaced through
    ``receive_messages``.

    Config keys read here: ``host`` (default ``"localhost"``) and ``port``
    (default ``49153``).
    """

    def __init__(self, config: Dict[str, Any] | None = None):
        """Initialise transport state; the server is not started until ``start``.

        Args:
            config: Optional transport configuration (see class docstring).

        Raises:
            ImportError: If the ``websockets`` package could not be imported.
        """
        super().__init__(config)
        if not WEBSOCKETS_AVAILABLE:
            raise ImportError("websockets package required for WebSocket transport")
        # connection id -> live server-side connection object
        self.connections: Dict[str, ServerConnection] = {}
        # connection id -> parsed TransportMessages received from that client
        self.message_queues: Dict[str, asyncio.Queue] = {}
        # NOTE(review): assumes BaseTransport.__init__ leaves a dict-like
        # ``self.config`` (it is read via ``.get``) even when config is None.
        self.host = self.config.get("host", "localhost")
        self.port = self.config.get("port", 49153)
        self.server = None
        self.connection_id = f"ws-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Start the WebSocket server and mark the transport as running."""
        self.server = await websockets.serve(
            self.handle_client,
            self.host,
            self.port
        )
        self.is_running = True
        logger.info(f"WebSocket transport started on ws://{self.host}:{self.port}")

    async def stop(self) -> None:
        """Close every client connection, then shut the server down.

        Closing clients is best-effort: a failure to close one connection is
        logged and does not prevent the others, or the server, from closing.
        """
        self.is_running = False
        # Snapshot the items: handle_client's ``finally`` removes entries while
        # we are suspended in ``close()``.
        for conn_id, websocket in list(self.connections.items()):
            try:
                await websocket.close()
            except Exception as e:
                # Only Exception is caught, so task cancellation still propagates.
                logger.debug(f"Error closing WebSocket client {conn_id}: {e}")
        self.connections.clear()
        self.message_queues.clear()
        if self.server:
            self.server.close()
            await self.server.wait_closed()

    async def send_message(self, message: TransportMessage) -> None:
        """Broadcast ``message`` to every connected client.

        Clients whose send raises are removed from ``connections`` and
        ``message_queues``. Does nothing when the transport is not running.
        """
        if not self.is_running:
            return
        data = message.model_dump_json(exclude_none=True)
        disconnected = []
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # handle_client may add or remove connections meanwhile; iterating the
        # live dict would then raise "dictionary changed size during iteration".
        for conn_id, websocket in list(self.connections.items()):
            try:
                await websocket.send(data)
            except Exception as e:
                logger.warning(f"Failed to send to WebSocket client {conn_id}: {e}")
                disconnected.append(conn_id)
        # Clean up disconnected clients
        for conn_id in disconnected:
            self.connections.pop(conn_id, None)
            self.message_queues.pop(conn_id, None)

    async def receive_messages(self) -> AsyncGenerator[TransportMessage, None]:
        """Yield messages received from any client while the transport runs.

        Each pass drains the messages already waiting in every per-client
        queue without blocking. The loop sleeps 0.1 s only after a pass that
        found nothing, so a backlog is delivered promptly rather than one
        message per client per 0.1 s. Order is FIFO per client; no ordering
        is guaranteed across clients.
        """
        while self.is_running:
            delivered = False
            # Snapshot: clients may connect or disconnect while we are suspended.
            for queue in list(self.message_queues.values()):
                # Take only what is queued at the start of the pass, so one
                # busy client cannot starve the others.
                for _ in range(queue.qsize()):
                    try:
                        message = queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    delivered = True
                    yield message
            # Always give the event loop a turn; back off only when idle.
            await asyncio.sleep(0 if delivered else 0.1)

    async def handle_client(self, websocket: ServerConnection):
        """Serve one client connection until it closes.

        Registers the connection and sends a ``connection.established``
        message. It then parses each incoming frame as JSON into a
        ``TransportMessage`` and enqueues it for ``receive_messages``.
        Malformed frames are logged and skipped. The connection is always
        unregistered on exit.
        """
        conn_id = f"ws-{uuid.uuid4().hex[:8]}"
        self.connections[conn_id] = websocket
        self.message_queues[conn_id] = asyncio.Queue()
        logger.info(f"WebSocket client connected: {conn_id}")
        try:
            # Send welcome message
            welcome_msg = TransportMessage(
                id=uuid.uuid4().hex,
                method="connection.established",
                params={"connection_id": conn_id, "transport": "websocket"}
            )
            await websocket.send(welcome_msg.model_dump_json(exclude_none=True))
            # Handle messages from this client
            async for message in websocket:
                queue = self.message_queues.get(conn_id)
                if queue is None:
                    # A failed send (send_message / send_to_client) already
                    # dropped this client; stop serving it instead of raising
                    # KeyError on every subsequent message.
                    logger.warning(f"WebSocket client {conn_id} was dropped by the transport; closing handler")
                    break
                try:
                    data = json.loads(message)
                    transport_msg = TransportMessage(**data)
                    await queue.put(transport_msg)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON from {conn_id}: {message}")
                except Exception as e:
                    logger.error(f"Message handling error for {conn_id}: {e}")
        except ConnectionClosed:
            # Abnormal close; a clean close simply ends the ``async for``.
            # Both cases are logged in ``finally``.
            pass
        except Exception as e:
            logger.warning(f"WebSocket connection error for {conn_id}: {e}")
        finally:
            self.connections.pop(conn_id, None)
            self.message_queues.pop(conn_id, None)
            logger.info(f"WebSocket client disconnected: {conn_id}")

    async def send_to_client(self, conn_id: str, message: TransportMessage):
        """Send ``message`` to the single client identified by ``conn_id``.

        Unknown ids are ignored. If the send fails, the failure is logged and
        the client is removed, as in ``send_message``. Serialisation errors
        propagate to the caller instead of being mistaken for a dead
        connection.
        """
        websocket = self.connections.get(conn_id)
        if websocket is None:
            return
        data = message.model_dump_json(exclude_none=True)
        try:
            await websocket.send(data)
        except Exception as e:
            logger.warning(f"Failed to send to WebSocket client {conn_id}: {e}")
            # Clean up disconnected client
            self.connections.pop(conn_id, None)
            self.message_queues.pop(conn_id, None)

    def get_connection_info(self) -> Dict[str, Any]:
        """Return a snapshot describing the server and its active connections."""
        return {
            "transport": "websocket",
            "running": self.is_running,
            "connections": len(self.connections),
            "host": self.host,
            "port": self.port,
            "connection_ids": list(self.connections.keys())
        }