websocket_client.py•19.4 kB
#!/usr/bin/env python3
"""
Hierarchy Cache WebSocket Client
WebSocket client for receiving real-time Location and Department entity updates
from the Motadata ITSM server and updating the hierarchy caches accordingly.
This client:
1. Authenticates using OAuth2 client_credentials grant
2. Obtains a WebSocket handshake token
3. Connects to the WebSocket server using STOMP protocol
4. Subscribes to python-service notifications
5. Processes Location and Department entity updates
6. Updates the hierarchy caches in real-time
"""
import json
import re
import requests
import websocket
import threading
import time
import logging
import os
from typing import Dict, Any, Optional
from urllib.parse import urlparse
from stompest.protocol import StompParser
logger = logging.getLogger(__name__)
class HierarchyCacheWebSocketClient:
    """
    WebSocket client for hierarchy cache synchronization.

    Connects to the Motadata ITSM WebSocket server, speaks STOMP 1.2 over the
    socket, and applies Location and Department entity updates/removals sent
    by the server to the hierarchy caches in real-time.

    :meth:`connect` blocks while a session runs and reconnects in a loop after
    the socket closes (run it in a dedicated thread if the caller must not
    block). :meth:`disconnect` stops the loop and closes the socket.
    """

    # STOMP destination this service subscribes to once CONNECTED arrives.
    SUBSCRIBE_DESTINATION = "/users/execute-task/python-service"

    # STOMP "heart-beat" header sent in CONNECT: "<client sends>,<client wants>".
    # The first value is 0 because this client has no heart-beat sender;
    # advertising a non-zero value without sending heart-beats lets a broker
    # that enforces them drop the session as idle. The server is still asked
    # to send a heart-beat every 10 s.
    CLIENT_HEART_BEAT = "0,10000"

    # Whole "content-length" header lines (case-insensitive). The server sends
    # incorrect values, so the header is removed and stompest falls back to the
    # NUL terminator to find the end of each frame body.
    _CONTENT_LENGTH_RE = re.compile(
        r'^content-length:[^\n]*\n', re.IGNORECASE | re.MULTILINE
    )

    # Cache operation -> (past tense, progressive) wording for log messages.
    _OPERATION_WORDS = {
        "update": ("Updated", "updating"),
        "remove": ("Removed", "removing"),
    }

    def __init__(
        self,
        cache_manager,
        server_url: str,
        access_token: str,
        client_id: str = "python-service",
        # NOTE(review): a real credential is hard-coded as a default here. It
        # should come from configuration/secrets management, and the literal
        # should be rotated and removed from source control.
        client_secret: str = "IZqQXkqA1tGLIBRpYs",
        instance_id: str = "python-service-1"
    ):
        """
        Initialize the WebSocket client (no network I/O happens here).

        Args:
            cache_manager: HierarchyCacheManager instance for cache updates.
            server_url: Base URL of the Motadata server (e.g.
                https://your-server.com); a trailing '/' is stripped.
            access_token: Initial OAuth2 access token; it is replaced by
                authenticate() at the start of every connection attempt.
            client_id: OAuth2 client ID (also sent as the OAuth2 scope).
            client_secret: OAuth2 client secret.
            instance_id: Identifier for this service instance (stored only;
                not otherwise used by this class).
        """
        self.cache_manager = cache_manager
        self.server_url = server_url.rstrip('/')
        self.client_id = client_id
        self.client_secret = client_secret
        self.instance_id = instance_id
        self.access_token = access_token
        self.ws: Optional[websocket.WebSocketApp] = None
        self.connected = False
        # Set when a STOMP CONNECTED frame arrives; cleared when the socket closes.
        self.connection_established = threading.Event()
        # Set on socket errors / giving up; cleared again on a successful CONNECTED.
        self.connection_failed = threading.Event()
        self.reconnect_attempts = 0
        self.max_reconnect_attempts = 10
        self.reconnect_delay = 5  # seconds
        # Set by disconnect(); also makes the reconnect back-off interruptible.
        self._stop_event = threading.Event()
        # STOMP parser for protocol-compliant frame parsing (replaced on every
        # new connection so partial data from a dropped socket is discarded).
        self.stomp_parser = StompParser(version='1.2')

    def authenticate(self) -> str:
        """
        Obtain an OAuth2 access token using the client_credentials grant.

        The token is stored in ``self.access_token`` and also returned.

        Returns:
            Access token string.

        Raises:
            requests.exceptions.RequestException: If the HTTP request fails
                or returns an error status.
            KeyError: If the response JSON has no 'access_token'.
        """
        token_url = f"{self.server_url}/api/oauth/token"
        data = {
            'grant_type': 'client_credentials',
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'scope': self.client_id
        }
        logger.info(f"Authenticating with server: {token_url}")
        try:
            response = requests.post(token_url, data=data, timeout=10)
            response.raise_for_status()
            token_data = response.json()
            self.access_token = token_data['access_token']
            # Used only for the log line below.
            expires_in = token_data.get('expires_in', 172800)
            logger.info(f"Authentication successful. Token expires in {expires_in} seconds")
            return self.access_token
        except requests.exceptions.RequestException as e:
            logger.error(f"Authentication failed: {e}")
            raise

    def get_handshake_token(self) -> str:
        """
        Get a WebSocket handshake token from the server.

        Requires a valid ``self.access_token`` (see authenticate()).

        Returns:
            Handshake token string.

        Raises:
            requests.exceptions.RequestException: If the HTTP request fails
                or returns an error status.
            KeyError: If the response JSON has no 'token'.
        """
        handshake_url = f"{self.server_url}/api/socket/handshake-token"
        headers = {'Authorization': f'Bearer {self.access_token}'}
        logger.info("Requesting WebSocket handshake token")
        try:
            response = requests.get(handshake_url, headers=headers, timeout=10)
            response.raise_for_status()
            token = response.json()['token']
            logger.info("Handshake token obtained successfully")
            return token
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get handshake token: {e}")
            raise

    @classmethod
    def _strip_content_length(cls, message: str) -> str:
        """Remove every whole 'content-length' header line from a raw STOMP message."""
        return cls._CONTENT_LENGTH_RE.sub('', message)

    def on_message(self, ws, message):
        """
        Handle an incoming WebSocket message using the stompest protocol parser.

        Bare EOLs are treated as heart-beats. Everything else is fed to the
        STOMP parser and each complete frame is dispatched. An error in one
        frame is logged and does not stop the remaining frames.

        Args:
            ws: WebSocket instance.
            message: Raw message (str for text frames, bytes for binary frames).
        """
        logger.debug(f"Received raw message: {message[:200]}...")
        try:
            if isinstance(message, bytes):
                # Binary frames arrive as bytes; the header clean-up works on text.
                message = message.decode('utf-8')
            # STOMP 1.2 heart-beats are bare EOLs ("\n" or "\r\n").
            if not message.strip('\r\n'):
                logger.debug("Heartbeat received")
                return
            self.stomp_parser.add(self._strip_content_length(message).encode('utf-8'))
        except Exception as e:
            logger.error(f"Error parsing STOMP frame: {e}", exc_info=True)
            logger.debug(f"Raw message: {message}")
            return

        while self.stomp_parser.canRead():
            frame = self.stomp_parser.get()
            try:
                self._dispatch_frame(ws, frame)
            except Exception as e:
                logger.error(f"Error handling STOMP frame: {e}", exc_info=True)

    def _dispatch_frame(self, ws, frame):
        """Route one parsed STOMP frame to the handler for its command."""
        # NOTE(review): stompest may yield heart-beat objects that carry no
        # `command` attribute; treat anything without one as a heart-beat.
        command = getattr(frame, 'command', None)
        if command is None:
            logger.debug("Heartbeat received")
        elif command == 'CONNECTED':
            self._on_connected(ws, frame)
        elif command == 'MESSAGE':
            self._on_stomp_message(frame)
        elif command == 'ERROR':
            logger.error("STOMP ERROR frame received")
            logger.error(f"Headers: {frame.headers}")
            body = frame.body.decode('utf-8', errors='replace') if frame.body else 'N/A'
            logger.error(f"Body: {body}")
        else:
            logger.debug(f"Received STOMP frame: {command}")

    def _on_connected(self, ws, frame):
        """Mark the STOMP session as established and subscribe to the service destination."""
        logger.info("✅ STOMP connection established")
        logger.debug(f"CONNECTED headers: {frame.headers}")
        self.connected = True
        self.reconnect_attempts = 0
        self.connection_failed.clear()
        # Subscribe only now that the broker has accepted the session.
        self._subscribe(ws)
        self.connection_established.set()

    def _subscribe(self, ws):
        """Send the SUBSCRIBE frame for SUBSCRIBE_DESTINATION."""
        subscribe_frame = build_stomp_message("SUBSCRIBE", {
            "id": "sub-0",
            "destination": self.SUBSCRIBE_DESTINATION,
            "model": "python_service",
            "refid": "1",
        })
        ws.send(subscribe_frame)
        logger.info(f"Subscribed to {self.SUBSCRIBE_DESTINATION}")

    def _on_stomp_message(self, frame):
        """Decode a MESSAGE frame's JSON body and apply it as an entity update."""
        logger.debug(f"MESSAGE headers: {frame.headers}")
        if not frame.body:
            logger.warning("MESSAGE frame has empty body")
            return
        try:
            body_str = frame.body.decode('utf-8')
            logger.debug(f"Extracted JSON body: {body_str}")
            payload = json.loads(body_str)
        except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
            logger.error(f"Failed to parse JSON payload: {e}")
            logger.debug(f"Frame body: {frame.body}")
            return
        logger.info(f"✅ Parsed payload: {payload}")
        self.handle_entity_update(payload)

    def handle_entity_update(self, payload: Dict[str, Any]):
        """
        Process an entity notification and route it to the matching handler.

        Payload keys read by this method:
        {
            "id": <entity_id>,            # required
            "name": "<entity_name>",      # required
            "parentId": <parent_id>,      # optional, defaults to 0 (root)
            "model": "location" | ...,    # optional
            "removed": true | false       # optional, defaults to false
        }

        Payloads with ``model == "location"`` go to the Location cache; every
        other payload (including one without "model") goes to the Department
        cache. ``removed`` selects remove vs. update. Errors are logged, never
        raised.

        Args:
            payload: Decoded JSON object from the MESSAGE frame body.
        """
        try:
            if not isinstance(payload, dict):
                logger.warning(f"Incomplete entity update payload: {payload}")
                return
            entity_id = payload.get('id')
            parent_id = payload.get('parentId', 0)  # 0 means root
            name = payload.get('name')
            model = payload.get('model')
            removed = payload.get('removed', False)
            if entity_id is None or name is None:
                logger.warning(f"Incomplete entity update payload: {payload}")
                return
            logger.info(f"Received entity update: id={entity_id}, name='{name}', parentId={parent_id}")
            if model == "location":
                if removed:
                    self.handle_location_remove(entity_id, name, parent_id)
                else:
                    self.handle_location_update(entity_id, name, parent_id)
            else:
                if removed:
                    self.handle_department_remove(entity_id, name, parent_id)
                else:
                    self.handle_department_update(entity_id, name, parent_id)
        except Exception as e:
            logger.error(f"Error handling entity update: {e}", exc_info=True)

    def _apply_cache_change(self, entity_type: str, operation: str,
                            entity_id: int, name: str, parent_id: int) -> bool:
        """
        Apply one update/remove to the Location or Department cache.

        Args:
            entity_type: "location" or "department" (selects the cache).
            operation: "update" (cache.update_node) or "remove" (cache.remove_node).
            entity_id: Entity ID.
            name: Entity name.
            parent_id: Parent ID; 0 is passed to the cache as None (root).

        Returns:
            The cache call's result; False if the cache manager/cache is
            unavailable or an exception was raised (it is logged).
        """
        past, progressive = self._OPERATION_WORDS[operation]
        try:
            if not self.cache_manager or not self.cache_manager.is_initialized():
                logger.warning(f"Cache manager not initialized, cannot {operation} {entity_type}")
                return False
            if entity_type == "location":
                cache = self.cache_manager.get_location_cache()
            else:
                cache = self.cache_manager.get_department_cache()
            if not cache:
                logger.warning(f"{entity_type.capitalize()} cache not available")
                return False
            parent_id_value = None if parent_id == 0 else parent_id
            mutate = cache.update_node if operation == "update" else cache.remove_node
            success = mutate(entity_id, name, parent_id_value)
            if success:
                logger.info(
                    f"✅ {past} {entity_type.capitalize()} cache: "
                    f"id={entity_id}, name='{name}', parent={parent_id}"
                )
                # Derived lookup structures must be rebuilt after a tree change.
                cache._build_indices()
                cache._compute_paths()
            return success
        except Exception as e:
            logger.error(f"Error {progressive} {entity_type} {entity_id}: {e}", exc_info=True)
            return False

    def handle_location_remove(self, entity_id: int, name: str, parent_id: int) -> bool:
        """
        Remove a Location from the location cache.

        Args:
            entity_id: Location ID.
            name: Location name.
            parent_id: Parent location ID (0 for root).

        Returns:
            True if the cache reported the removal, False otherwise.
        """
        return self._apply_cache_change("location", "remove", entity_id, name, parent_id)

    def handle_location_update(self, entity_id: int, name: str, parent_id: int) -> bool:
        """
        Apply a Location update to the location cache.

        Args:
            entity_id: Location ID.
            name: Location name.
            parent_id: Parent location ID (0 for root).

        Returns:
            True if the cache reported the update, False otherwise.
        """
        return self._apply_cache_change("location", "update", entity_id, name, parent_id)

    def handle_department_remove(self, entity_id: int, name: str, parent_id: int) -> bool:
        """
        Remove a Department from the department cache.

        Args:
            entity_id: Department ID.
            name: Department name.
            parent_id: Parent department ID (0 for root).

        Returns:
            True if the cache reported the removal, False otherwise.
        """
        return self._apply_cache_change("department", "remove", entity_id, name, parent_id)

    def handle_department_update(self, entity_id: int, name: str, parent_id: int) -> bool:
        """
        Apply a Department update to the department cache.

        Args:
            entity_id: Department ID.
            name: Department name.
            parent_id: Parent department ID (0 for root).

        Returns:
            True if the cache reported the update, False otherwise.
        """
        return self._apply_cache_change("department", "update", entity_id, name, parent_id)

    def on_error(self, ws, error):
        """Log a WebSocket error and flag the connection as failed."""
        logger.error(f"WebSocket error: {error}")
        self.connection_failed.set()

    def on_close(self, ws, close_status_code, close_msg):
        """
        Record that the WebSocket closed.

        Reconnection is driven by the loop in connect(), not from inside this
        callback, so the call stack does not grow with every reconnect.
        """
        self.connected = False
        self.connection_established.clear()
        logger.warning(f"WebSocket closed: {close_status_code} - {close_msg}")

    def on_open(self, ws):
        """
        Send the STOMP CONNECT frame once the WebSocket is open.

        SUBSCRIBE is sent later, when the CONNECTED frame arrives.

        Args:
            ws: WebSocket instance.
        """
        logger.info("WebSocket connection opened")
        # Fresh parser per connection: drop any partial frame from a previous socket.
        self.stomp_parser = StompParser(version='1.2')
        connect_headers = {
            "accept-version": "1.2",
            "heart-beat": self.CLIENT_HEART_BEAT,
        }
        ws.send(build_stomp_message("CONNECT", connect_headers))
        logger.info("CONNECT frame sent")

    def _run_session(self):
        """Authenticate, fetch a handshake token, and block in run_forever() until the socket closes."""
        self.authenticate()
        handshake_token = self.get_handshake_token()
        if self._stop_event.is_set():
            return
        # SockJS endpoints accept a raw WebSocket at the /websocket suffix; the
        # server has the /api context path, so the full path is
        # /api/public/mtdtsocket/websocket.
        parsed_url = urlparse(self.server_url)
        ws_scheme = 'wss' if parsed_url.scheme == 'https' else 'ws'
        ws_base = f"{ws_scheme}://{parsed_url.netloc}/api/public/mtdtsocket/websocket"
        # Log without the query string: the handshake token is a credential.
        logger.info(f"Connecting to WebSocket (SockJS format): {ws_base}")
        self.ws = websocket.WebSocketApp(
            f"{ws_base}?mtdt={handshake_token}",
            header=[
                "Authorization: Bearer " + self.access_token,
            ],
            on_open=self.on_open,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close
        )
        self.ws.run_forever()

    def connect(self):
        """
        Establish the WebSocket connection and keep it alive (blocking).

        Each session:
        1. Authenticates to get an OAuth2 token
        2. Gets a WebSocket handshake token
        3. Connects to the WebSocket endpoint (SockJS raw-WebSocket path)
        4. Runs the message loop until the socket closes

        After a session ends, it waits ``reconnect_delay`` seconds and starts
        another one, until disconnect() is called or ``max_reconnect_attempts``
        consecutive attempts fail without reaching CONNECTED. In the latter
        case connection_failed is set.

        Note: The Spring WebSocket endpoint is configured with SockJS
        (.withSockJS()), so the SockJS raw WebSocket format
        /endpoint/websocket is used.

        Raises:
            Exception: Whatever the *first* authentication/handshake attempt
                raises. Failures during later reconnects are logged and retried.
        """
        self._stop_event.clear()
        first_attempt = True
        while True:
            try:
                self._run_session()
            except Exception as e:
                logger.error(f"Connection failed: {e}")
                self.connection_failed.set()
                if first_attempt:
                    raise
            first_attempt = False

            if self._stop_event.is_set():
                logger.info("WebSocket client stopped")
                return
            if self.reconnect_attempts >= self.max_reconnect_attempts:
                logger.error("Max reconnection attempts reached. Giving up.")
                self.connection_failed.set()
                return
            self.reconnect_attempts += 1
            logger.info(
                f"Attempting reconnection {self.reconnect_attempts}/"
                f"{self.max_reconnect_attempts} in {self.reconnect_delay} seconds..."
            )
            # Interruptible back-off: disconnect() wakes this up immediately.
            if self._stop_event.wait(self.reconnect_delay):
                logger.info("WebSocket client stopped")
                return

    def disconnect(self):
        """Stop reconnecting and close the current WebSocket, if any."""
        self._stop_event.set()
        ws = self.ws
        if ws is not None:
            ws.close()
# STOMP 1.2 header escaping: backslash, CR, LF and colon. str.translate maps all
# characters in one pass, so an inserted backslash is never escaped twice.
_STOMP_HEADER_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "\r": "\\r",
    "\n": "\\n",
    ":": "\\c",
})


def build_stomp_message(command, headers=None, body=""):
    """
    Build a STOMP 1.2 frame as a string.

    The frame is the command line, one ``key:value`` line per header (in the
    mapping's iteration order), a blank line, the body and the NUL terminator.
    Header keys and values are escaped as STOMP 1.2 requires, except in
    CONNECT/CONNECTED frames, which the spec says must not be escaped.

    Args:
        command: STOMP command, e.g. "CONNECT" or "SUBSCRIBE".
        headers: Optional mapping of header name to value; both are converted
            with ``str``.
        body: Frame body (defaults to empty).

    Returns:
        The complete frame, ready to send over the WebSocket.
    """
    escape = command not in ("CONNECT", "CONNECTED")
    lines = [command]
    for key, value in (headers or {}).items():
        key, value = str(key), str(value)
        if escape:
            key = key.translate(_STOMP_HEADER_ESCAPES)
            value = value.translate(_STOMP_HEADER_ESCAPES)
        lines.append(f"{key}:{value}")
    return "\n".join(lines) + "\n\n" + body + "\x00"