websocket_server.py•16.3 kB
"""WebSocket 서버 구현"""
import asyncio
import inspect
import json
import logging
import time
from datetime import datetime
from typing import Dict, List, Set, Any, Optional, Callable

import websockets
from websockets import WebSocketServerProtocol
from websockets.exceptions import ConnectionClosed, WebSocketException
class WebSocketServer:
    """JSON-over-WebSocket server with channel subscriptions and heartbeat checks.

    Clients send JSON objects carrying a ``type`` field. ``ping``, ``subscribe``
    and ``unsubscribe`` are handled here; any other object is forwarded to the
    ``on_message`` callback. Callbacks may be plain callables or coroutine
    functions (anything whose return value is awaitable is awaited).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Args:
            config: Server settings. Recognised keys (defaults in brackets):
                ``host`` ["localhost"], ``port`` [8765],
                ``max_connections`` [1000], ``heartbeat_interval`` [30 seconds],
                ``compression`` [False], ``authentication`` [False].
        """
        if config is None:
            config = {}
        # Settings
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 8765)
        self.max_connections = config.get("max_connections", 1000)
        self.heartbeat_interval = config.get("heartbeat_interval", 30)
        self.compression = config.get("compression", False)
        self.compression_enabled = self.compression  # alias kept for test compatibility
        self.authentication = config.get("authentication", False)
        self.authentication_enabled = self.authentication  # alias kept for test compatibility
        # NOTE(review): ``authentication`` is stored but not enforced anywhere in this class.

        # Connection bookkeeping
        self.connections: Set[WebSocketServerProtocol] = set()
        self.subscriptions: Dict[str, Set[WebSocketServerProtocol]] = {}
        self.connection_info: Dict[WebSocketServerProtocol, Dict] = {}

        # Server state
        self.server = None
        self.is_running = False
        self.start_time = None

        # Statistics (active_connections / subscriptions_count are recomputed in get_stats)
        self.stats = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_sent": 0,
            "messages_received": 0,
            "subscriptions_count": 0,
            "errors": 0
        }

        # User callbacks
        self.on_connect: Optional[Callable] = None
        self.on_disconnect: Optional[Callable] = None
        self.on_message: Optional[Callable] = None
        self.on_error: Optional[Callable] = None

        # Background tasks
        self._heartbeat_task: Optional[asyncio.Task] = None

        self.logger = logging.getLogger(__name__)

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self.connections)

    async def start(self) -> bool:
        """Start listening and launch the heartbeat task.

        Returns:
            True if the server is running (including when it already was),
            False if startup failed. Failures are logged, counted in
            ``stats["errors"]`` and reported to ``on_error`` as
            ``("start_error", message)``.
        """
        try:
            if self.is_running:
                return True

            self.server = await websockets.serve(
                self._handle_connection,
                self.host,
                self.port,
                # NOTE(review): no limit on incoming message size or queue length.
                max_size=None,
                max_queue=None,
                # Honour the ``compression`` setting: the library negotiates
                # permessage-deflate unless compression is explicitly None.
                compression="deflate" if self.compression else None,
            )

            self.is_running = True
            self.start_time = time.time()

            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

            self.logger.info(f"WebSocket server started on {self.host}:{self.port}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to start server: {e}")
            self.stats["errors"] += 1
            if self.on_error:
                await self._safe_callback(self.on_error, "start_error", str(e))
            return False

    async def stop(self) -> bool:
        """Stop the heartbeat task, close every client and shut the server down.

        Returns:
            True on success (or if the server was not running), False if an
            error occurred while stopping (logged and counted in ``stats``).
        """
        try:
            if not self.is_running:
                return True

            self.is_running = False

            if self._heartbeat_task:
                self._heartbeat_task.cancel()
                try:
                    await self._heartbeat_task
                except asyncio.CancelledError:
                    pass
                self._heartbeat_task = None

            # Close clients concurrently; individual failures must not abort shutdown.
            if self.connections:
                await asyncio.gather(
                    *[conn.close() for conn in self.connections.copy()],
                    return_exceptions=True
                )

            if self.server:
                self.server.close()
                await self.server.wait_closed()
                self.server = None

            self.logger.info("WebSocket server stopped")
            return True

        except Exception as e:
            self.logger.error(f"Error stopping server: {e}")
            self.stats["errors"] += 1
            return False

    async def broadcast(self, message: Dict[str, Any], channel: Optional[str] = None) -> int:
        """Send ``message`` (serialised as JSON) to many clients at once.

        Args:
            message: JSON-serialisable payload; non-serialisable values are
                converted with ``str``.
            channel: When truthy, only subscribers of this channel receive the
                message (none if the channel has no subscribers). When falsy,
                every connected client receives it.

        Returns:
            Number of clients the message was delivered to. Clients whose send
            failed are removed from the server.
        """
        if not self.is_running:
            return 0

        try:
            message_str = json.dumps(message, ensure_ascii=False, default=str)

            # Snapshot the targets: sends await, and disconnects may mutate
            # the live sets meanwhile.
            if channel:
                targets = list(self.subscriptions.get(channel, ()))
            else:
                targets = list(self.connections)
            if not targets:
                return 0

            # Send concurrently so one slow client does not delay the others.
            results = await asyncio.gather(
                *(self._send_raw(conn, message_str) for conn in targets)
            )

            sent_count = 0
            for conn, delivered in zip(targets, results):
                if delivered:
                    sent_count += 1
                else:
                    await self._remove_connection(conn)

            self.stats["messages_sent"] += sent_count
            return sent_count

        except Exception as e:
            self.logger.error(f"Broadcast error: {e}")
            self.stats["errors"] += 1
            return 0

    async def _send_raw(self, connection: WebSocketServerProtocol, message_str: str) -> bool:
        """Send an already-serialised frame; return False instead of raising on failure."""
        try:
            await connection.send(message_str)
            return True
        except (ConnectionClosed, WebSocketException):
            return False
        except Exception as e:
            self.logger.error(f"Error sending message: {e}")
            return False

    async def send_to_connection(self, connection: WebSocketServerProtocol,
                                 message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON to a single registered connection.

        Returns:
            True if sent; False if the connection is not registered or the
            send failed. A closed connection is removed from the server.
        """
        try:
            if connection not in self.connections:
                return False

            message_str = json.dumps(message, ensure_ascii=False, default=str)
            await connection.send(message_str)
            self.stats["messages_sent"] += 1
            return True

        except (ConnectionClosed, WebSocketException):
            await self._remove_connection(connection)
            return False
        except Exception as e:
            self.logger.error(f"Error sending to connection: {e}")
            self.stats["errors"] += 1
            return False

    def get_connections(self) -> List[Dict[str, Any]]:
        """Describe every registered connection.

        Returns:
            One dict per connection with keys ``id``, ``remote_address``,
            ``connect_time``, ``last_ping`` and ``subscriptions`` (channel list).
        """
        result = []
        for connection in self.connections:
            info = self.connection_info.get(connection, {})
            result.append({
                "id": id(connection),
                "remote_address": getattr(connection, 'remote_address', None),
                "connect_time": info.get("connect_time"),
                "last_ping": info.get("last_ping"),
                "subscriptions": [ch for ch, subs in self.subscriptions.items()
                                  if connection in subs]
            })
        return result

    def get_subscriptions(self) -> Dict[str, int]:
        """Map each channel to its current subscriber count."""
        return {channel: len(connections) for channel, connections in self.subscriptions.items()}

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus live connection/channel counts."""
        uptime = time.time() - self.start_time if self.start_time else 0

        return {
            **self.stats,
            "active_connections": len(self.connections),
            "subscriptions_count": len(self.subscriptions),
            "uptime_seconds": uptime,
            "is_running": self.is_running,
            "server_info": {
                "host": self.host,
                "port": self.port,
                "max_connections": self.max_connections
            }
        }

    async def add_subscription(self, connection: WebSocketServerProtocol, channel: str) -> bool:
        """Subscribe a registered connection to ``channel``.

        Returns:
            True on success; False if the connection is not registered or the
            channel key is unusable (e.g. unhashable; logged).
        """
        try:
            if connection not in self.connections:
                return False
            self.subscriptions.setdefault(channel, set()).add(connection)
            return True
        except Exception as e:
            self.logger.error(f"Error adding subscription: {e}")
            return False

    async def remove_subscription(self, connection: WebSocketServerProtocol, channel: str) -> bool:
        """Unsubscribe ``connection`` from ``channel``, dropping the channel once empty.

        Returns:
            True if the connection was subscribed and has been removed,
            otherwise False.
        """
        try:
            if channel in self.subscriptions and connection in self.subscriptions[channel]:
                self.subscriptions[channel].remove(connection)
                if not self.subscriptions[channel]:
                    del self.subscriptions[channel]
                return True
            return False
        except Exception as e:
            self.logger.error(f"Error removing subscription: {e}")
            return False

    async def _handle_connection(self, websocket: WebSocketServerProtocol,
                                 path: Optional[str] = None):
        """Serve one client until it disconnects.

        Args:
            websocket: The client connection.
            path: Request path. Older websockets releases pass it to the
                handler; newer ones call the handler with the connection only,
                in which case it is read from ``websocket.path`` when present.
        """
        if path is None:
            path = getattr(websocket, "path", None)

        if len(self.connections) >= self.max_connections:
            # 1013 = "Try Again Later"
            await websocket.close(code=1013, reason="Too many connections")
            return

        now = time.time()
        self.connections.add(websocket)
        self.connection_info[websocket] = {
            "connect_time": now,
            "last_ping": now,
            "path": path
        }
        self.stats["total_connections"] += 1
        self.stats["active_connections"] = len(self.connections)

        try:
            if self.on_connect:
                await self._safe_callback(self.on_connect, websocket, path)

            async for message in websocket:
                await self._handle_message(websocket, message)

        except ConnectionClosed:
            pass
        except Exception as e:
            self.logger.error(f"Connection error: {e}")
            self.stats["errors"] += 1
            if self.on_error:
                await self._safe_callback(self.on_error, "connection_error", str(e), websocket)
        finally:
            # Idempotent: a no-op if broadcast/heartbeat already removed it.
            await self._remove_connection(websocket)

    async def _handle_message(self, websocket: WebSocketServerProtocol, message: str):
        """Parse one incoming frame and dispatch it by its ``type`` field."""
        try:
            self.stats["messages_received"] += 1

            try:
                data = json.loads(message)
            except ValueError:
                # JSONDecodeError, and UnicodeDecodeError for undecodable
                # binary frames, are both ValueError subclasses.
                await self._send_error(websocket, "Invalid JSON")
                return

            # Valid JSON that is not an object (list, number, ...) is a client
            # error, not a server fault.
            if not isinstance(data, dict):
                await self._send_error(websocket, "Invalid message format")
                return

            msg_type = data.get("type")

            if msg_type == "ping":
                await self._handle_ping(websocket, data)
            elif msg_type == "subscribe":
                await self._handle_subscribe(websocket, data)
            elif msg_type == "unsubscribe":
                await self._handle_unsubscribe(websocket, data)
            elif self.on_message:
                # Application-defined message types
                await self._safe_callback(self.on_message, websocket, data)

        except Exception as e:
            self.logger.error(f"Message handling error: {e}")
            self.stats["errors"] += 1
            await self._send_error(websocket, "Message processing error")

    async def _handle_ping(self, websocket: WebSocketServerProtocol, data: Dict):
        """Record the client's liveness and answer with a ``pong``."""
        info = self.connection_info.get(websocket)
        if info is not None:  # may already have been removed concurrently
            info["last_ping"] = time.time()

        await self.send_to_connection(websocket, {
            "type": "pong",
            "timestamp": time.time()
        })

    async def _handle_subscribe(self, websocket: WebSocketServerProtocol, data: Dict):
        """Handle a ``subscribe`` request and reply with ``subscribe_response``."""
        channel = data.get("channel")
        if not channel:
            await self._send_error(websocket, "Channel required")
            return

        success = await self.add_subscription(websocket, channel)
        await self.send_to_connection(websocket, {
            "type": "subscribe_response",
            "channel": channel,
            "success": success
        })

    async def _handle_unsubscribe(self, websocket: WebSocketServerProtocol, data: Dict):
        """Handle an ``unsubscribe`` request and reply with ``unsubscribe_response``."""
        channel = data.get("channel")
        if not channel:
            await self._send_error(websocket, "Channel required")
            return

        success = await self.remove_subscription(websocket, channel)
        await self.send_to_connection(websocket, {
            "type": "unsubscribe_response",
            "channel": channel,
            "success": success
        })

    async def _send_error(self, websocket: WebSocketServerProtocol, error: str):
        """Send an ``error`` message with a timestamp to one client."""
        await self.send_to_connection(websocket, {
            "type": "error",
            "message": error,
            "timestamp": time.time()
        })

    async def _remove_connection(self, websocket: WebSocketServerProtocol):
        """Forget a connection and all of its subscriptions.

        Safe to call repeatedly: ``on_disconnect`` fires only for the call that
        actually unregisters the connection.
        """
        try:
            was_registered = websocket in self.connections
            self.connections.discard(websocket)

            for channel in list(self.subscriptions):
                subscribers = self.subscriptions[channel]
                subscribers.discard(websocket)
                if not subscribers:
                    del self.subscriptions[channel]

            self.connection_info.pop(websocket, None)
            self.stats["active_connections"] = len(self.connections)

            if was_registered and self.on_disconnect:
                await self._safe_callback(self.on_disconnect, websocket)

        except Exception as e:
            self.logger.error(f"Error removing connection: {e}")

    async def _heartbeat_loop(self):
        """Drop clients whose last ``ping`` is older than twice ``heartbeat_interval``.

        ``last_ping`` is set at connect time and refreshed only by
        application-level ``{"type": "ping"}`` messages, so clients must send
        those to stay connected.
        """
        while self.is_running:
            try:
                now = time.time()
                stale_after = self.heartbeat_interval * 2

                dead_connections = [
                    conn for conn in list(self.connections)
                    if now - self.connection_info.get(conn, {}).get("last_ping", 0) > stale_after
                ]

                for conn in dead_connections:
                    await self._remove_connection(conn)
                # Close concurrently: each close may wait for the peer's
                # handshake. Close errors are ignored, but cancellation is not
                # swallowed (unlike a bare ``except:``).
                if dead_connections:
                    await asyncio.gather(
                        *(conn.close() for conn in dead_connections),
                        return_exceptions=True
                    )

                await asyncio.sleep(self.heartbeat_interval)

            except asyncio.CancelledError:
                # Explicit for Python < 3.8, where CancelledError subclasses Exception.
                raise
            except Exception as e:
                self.logger.error(f"Heartbeat error: {e}")
                await asyncio.sleep(self.heartbeat_interval)

    async def _safe_callback(self, callback: Callable, *args):
        """Invoke a user callback, awaiting its result if awaitable.

        Exceptions are logged and counted in ``stats["errors"]`` rather than
        propagated, so a faulty callback cannot break the server.
        """
        try:
            result = callback(*args)
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            self.logger.error(f"Callback error: {e}")
            self.stats["errors"] += 1