"""
JWT-based authentication for WebSocket connections.
Provides token generation and validation for secure node authentication.
"""
import time
import hmac
import hashlib
import json
import base64
from typing import Optional, Dict, Any
class JWTAuth:
    """
    Simple JWT authentication without external dependencies.

    Tokens are signed with HMAC-SHA256 (``HS256``) and use the compact
    ``header.payload.signature`` form, where each segment is unpadded
    URL-safe base64. The payload carries ``sub`` (node id), ``role``,
    ``iat`` (issued-at) and ``exp`` (expiry), the latter two as integer
    UNIX timestamps in seconds.
    """

    # The only signing algorithm this class issues or accepts.
    _ALGORITHM = 'HS256'

    def __init__(self, secret_key: str):
        """
        Initialize JWT authenticator.

        Args:
            secret_key: Secret key for signing tokens; stored UTF-8 encoded.
        """
        self.secret_key = secret_key.encode('utf-8')

    def generate_token(
        self,
        node_id: str,
        role: str,
        expiry_seconds: int = 3600
    ) -> str:
        """
        Generate a signed JWT for a node.

        Args:
            node_id: Unique node identifier (stored as the ``sub`` claim).
            role: Node role (bootstrap_manager, developer, etc.).
            expiry_seconds: Seconds from now until the token expires.

        Returns:
            JWT token string of the form ``header.payload.signature``.
        """
        now = int(time.time())
        header = {'alg': self._ALGORITHM, 'typ': 'JWT'}
        payload = {
            'sub': node_id,
            'role': role,
            'iat': now,
            'exp': now + expiry_seconds,
        }
        header_b64 = self._base64_encode(json.dumps(header))
        payload_b64 = self._base64_encode(json.dumps(payload))
        signature_b64 = self._base64_encode(self._sign(header_b64, payload_b64))
        return f"{header_b64}.{payload_b64}.{signature_b64}"

    def validate_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Validate a JWT and extract its payload.

        A token is accepted only if it has three segments, its HMAC-SHA256
        signature matches, its header declares ``alg: HS256``, its payload
        is a JSON object, and its numeric ``exp`` lies strictly in the
        future (RFC 7519 section 4.1.4).

        Args:
            token: JWT token string.

        Returns:
            Payload dict if valid, None if invalid for any reason.
        """
        if not isinstance(token, str):
            return None
        parts = token.split('.')
        if len(parts) != 3:
            return None
        header_b64, payload_b64, signature_b64 = parts

        try:
            expected_b64 = self._base64_encode(self._sign(header_b64, payload_b64))
            # Constant-time comparison, so response timing does not reveal how
            # much of a forged signature matched. Compare bytes rather than
            # str, because compare_digest raises TypeError on non-ASCII str.
            if not hmac.compare_digest(
                signature_b64.encode('utf-8'),
                expected_b64.encode('ascii'),
            ):
                return None
            header = json.loads(self._base64_decode(header_b64))
            payload = json.loads(self._base64_decode(payload_b64))
        except (ValueError, RecursionError):
            # ValueError covers binascii.Error (bad base64), UnicodeError,
            # and json.JSONDecodeError; all of them mean "malformed token".
            return None

        if not isinstance(header, dict) or header.get('alg') != self._ALGORITHM:
            return None
        if not isinstance(payload, dict):
            return None

        exp = payload.get('exp')
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            return None
        # Valid only while now < exp. The negated form also rejects NaN.
        if not int(time.time()) < exp:
            return None
        return payload

    def _sign(self, header_b64: str, payload_b64: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of ``header_b64.payload_b64``."""
        message = f"{header_b64}.{payload_b64}".encode('utf-8')
        return hmac.new(self.secret_key, message, hashlib.sha256).digest()

    def _base64_encode(self, data: Any) -> str:
        """Base64 URL-safe encode str/bytes/dict data, stripping '=' padding."""
        if isinstance(data, str):
            data = data.encode('utf-8')
        elif isinstance(data, dict):
            data = json.dumps(data).encode('utf-8')
        encoded = base64.urlsafe_b64encode(data).decode('utf-8')
        return encoded.rstrip('=')

    def _base64_decode(self, data: str) -> str:
        """Base64 URL-safe decode unpadded ``data`` into a UTF-8 string."""
        # Restore the '=' padding removed by _base64_encode.
        data += '=' * (-len(data) % 4)
        decoded = base64.urlsafe_b64decode(data.encode('utf-8'))
        return decoded.decode('utf-8')
# Example usage
def example():
    """Show a token being issued and then checked by JWTAuth."""
    authenticator = JWTAuth(secret_key="your-secret-key-here")

    # Issue a one-hour token for a sample developer node.
    issued = authenticator.generate_token(
        node_id="node-123",
        role="developer",
        expiry_seconds=3600,
    )
    print(f"Token: {issued}")

    # Round-trip the token through validation and report the outcome.
    claims = authenticator.validate_token(issued)
    if not claims:
        print("Invalid token")
        return
    print(f"Valid token for node: {claims['sub']}")
    print(f"Role: {claims['role']}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    example()