"""
Cryptographic utilities for Zintlr MCP Server.
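Handles verification and decoding of the HS256-signed JWTs used in the Zintlr
authentication flow ("decryption" in this module's names means HMAC signature
verification plus payload decoding; the tokens are not encrypted).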
Supports both standard JSON payloads and string payloads (nested JWTs).
"""
import base64
import hashlib
import hmac
import json
import logging
from typing import Any
import jwt
from app.config import settings
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class JWTDecryptionError(Exception):
    """Error signalling that a JWT could not be verified and decoded."""
def _base64url_decode(data: str) -> bytes:
"""Decode base64url encoded data with proper padding."""
padding = 4 - len(data) % 4
if padding != 4:
data += "=" * padding
return base64.urlsafe_b64decode(data)
def _verify_hs256_signature(header_b64: str, payload_b64: str, signature_b64: str, secret: str) -> bool:
    """
    Check that *signature_b64* is the HMAC-SHA256 of ``header.payload``.

    Args:
        header_b64: The token's encoded header segment.
        payload_b64: The token's encoded payload segment.
        signature_b64: The token's base64url-encoded signature segment.
        secret: The shared HMAC secret (UTF-8 encoded before use).

    Returns:
        True when the signature matches; False when it does not or when the
        signature segment cannot be base64url-decoded.
    """
    message = f"{header_b64}.{payload_b64}".encode("utf-8")
    key = secret.encode("utf-8")
    computed = hmac.new(key, message, hashlib.sha256).digest()

    try:
        received = _base64url_decode(signature_b64)
    except Exception:
        return False

    # Constant-time comparison to avoid leaking how many bytes matched.
    return hmac.compare_digest(computed, received)
def _decode_string_payload_jwt(token: str, secret: str) -> Any | None:
    """
    Manually verify and decode an HS256 JWT whose payload is not a JSON object.

    Zintlr wraps JWTs inside other JWTs, where the payload is a raw string
    rather than a JSON object. PyJWT doesn't handle this, so the token is
    verified and decoded by hand here.

    Args:
        token: The compact JWT (``header.payload.signature``).
        secret: The HMAC-SHA256 secret.

    Returns:
        The payload parsed as JSON when it is valid JSON (e.g. a JSON string,
        array or number), otherwise the raw UTF-8 payload string. ``None`` when
        the token is malformed, its header is not a JSON object declaring
        ``"alg": "HS256"``, or its signature does not verify.

    NOTE: registered claims such as ``exp`` are NOT validated here.
    """
    parts = token.split(".")
    if len(parts) != 3:
        return None
    header_b64, payload_b64, signature_b64 = parts

    # Verify signature
    if not _verify_hs256_signature(header_b64, payload_b64, signature_b64, secret):
        logger.warning("JWT signature verification failed")
        return None

    # A valid MAC alone is not enough: also require a well-formed header that
    # declares HS256, mirroring jwt.decode(..., algorithms=["HS256"]).
    try:
        header = json.loads(_base64url_decode(header_b64))
    except ValueError as e:  # binascii.Error / UnicodeDecodeError / JSONDecodeError
        logger.warning("Failed to decode JWT header: %s", e)
        return None
    if not isinstance(header, dict) or header.get("alg") != "HS256":
        logger.warning("JWT header does not declare the HS256 algorithm")
        return None

    # Decode payload
    try:
        payload_str = _base64url_decode(payload_b64).decode("utf-8")
    except ValueError as e:  # binascii.Error / UnicodeDecodeError
        logger.warning("Failed to decode JWT payload: %s", e)
        return None

    # Try JSON first, fall back to raw string
    try:
        return json.loads(payload_str)
    except json.JSONDecodeError:
        return payload_str
def verify_and_decrypt_jwt(token: str, secret: str | None = None) -> Any | None:
    """
    Verify and decode an HS256-signed JWT.

    Handles both standard JSON-object payloads and the non-object payloads
    (raw strings such as nested JWTs) that Zintlr uses for its
    authentication tokens.

    Args:
        token: The JWT to verify and decode.
        secret: The HMAC secret (defaults to ``settings.cipher_secret``; an
            empty string also falls back to the setting).

    Returns:
        The decoded payload if the token is valid, ``None`` otherwise
        (including for an empty token).

    Raises:
        ValueError: If no secret is available (CIPHER_SECRET is not configured).
    """
    if not token:
        return None

    secret = secret or settings.cipher_secret
    if not secret:
        raise ValueError("CIPHER_SECRET is not configured")

    # Try standard PyJWT decode first (handles JSON-object payloads and
    # validates claims such as expiry).
    try:
        return jwt.decode(token, secret, algorithms=["HS256"])
    except jwt.exceptions.DecodeError as e:
        # PyJWT raises "Invalid payload string" only after the JWS layer has
        # parsed the header and verified the HS256 signature, i.e. the token is
        # authentic but its payload is not a JSON object. Only that case goes to
        # the manual string-payload decoder. Header parse errors (whose messages
        # can also contain json's "Expecting value") are rejected instead of
        # bypassing PyJWT's header/algorithm checks.
        # NOTE: relies on PyJWT's error message text.
        if "Invalid payload string" in str(e):
            return _decode_string_payload_jwt(token, secret)
        logger.warning("JWT decode error: %s", e)
        return None
    except jwt.ExpiredSignatureError:
        logger.warning("JWT token has expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning("Invalid JWT token: %s", e)
        return None
    except Exception as e:
        # Unexpected failure: keep the best-effort None, but record the traceback.
        logger.exception("Unexpected JWT decode error: %s", e)
        return None
def decrypt_user_tokens(encrypted_key: str, encrypted_access_token: str) -> tuple[Any, Any]:
    """
    Verify and decode the 'key' and 'access_token' values of a user session.

    The key is processed first; if it fails, the access token is not touched.

    Args:
        encrypted_key: The signed 'key' token.
        encrypted_access_token: The signed 'access_token' token.

    Returns:
        A ``(decoded_key, decoded_access_token)`` pair.

    Raises:
        JWTDecryptionError: If either token cannot be verified/decoded.
    """
    def _require(token: str, failure_message: str) -> Any:
        # verify_and_decrypt_jwt signals failure with None; escalate it.
        value = verify_and_decrypt_jwt(token)
        if value is None:
            raise JWTDecryptionError(failure_message)
        return value

    key = _require(encrypted_key, "Failed to decrypt key token")
    access_token = _require(encrypted_access_token, "Failed to decrypt access_token")
    return key, access_token