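"""Token Store for persistence of authentication tokens.

Uses msal-extensions (``FilePersistence``) to cache MSAL tokens in a file at the
configured cache path (e.g. ~/.sso-mcp-server/token_cache.bin). Note that
``FilePersistence`` stores the cache as plain text; it is not encrypted.
"""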
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from msal_extensions import (
    FilePersistence,
    PersistedTokenCache,
)

from sso_mcp_server import get_logger

if TYPE_CHECKING:
    from pathlib import Path
# Module logger named "token_store"; TokenStore emits its debug/info events through it.
_logger = get_logger("token_store")
class TokenStore:
    """Token persistence for MSAL using msal-extensions.

    Stores OAuth tokens in a file-backed ``PersistedTokenCache`` so that a
    restarted server can resume the previous session.

    Note:
        The cache is written through ``msal_extensions.FilePersistence``,
        which stores it as plain text (it is not encrypted).

    Attributes:
        cache_path: Path to the token cache file.
    """

    # Serialized MSAL cache in which every credential section is empty.
    _EMPTY_CACHE_STATE = (
        '{"AccessToken": {}, "RefreshToken": {}, "IdToken": {}, "Account": {}}'
    )

    def __init__(self, cache_path: Path) -> None:
        """Initialize the token store.

        Creates the parent directory if it doesn't exist. The persistence
        layer and the token cache are created lazily on first use.

        Args:
            cache_path: Path to the token cache file.
        """
        self._cache_path = cache_path
        self._persistence: FilePersistence | None = None
        self._cache: PersistedTokenCache | None = None
        # Ensure parent directory exists
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        _logger.debug("token_store_initialized", cache_path=str(cache_path))

    @property
    def cache_path(self) -> Path:
        """Get the path to the token cache file.

        Returns:
            Path to the cache file.
        """
        return self._cache_path

    def _get_persistence(self) -> FilePersistence:
        """Return the file persistence for ``cache_path``, creating it once."""
        if self._persistence is None:
            self._persistence = FilePersistence(str(self._cache_path))
        return self._persistence

    def get_cache(self) -> PersistedTokenCache:
        """Get the MSAL token cache.

        Creates the cache on first access using FilePersistence
        for cross-platform compatibility.

        Returns:
            PersistedTokenCache instance for use with MSAL.
        """
        if self._cache is None:
            self._cache = PersistedTokenCache(self._get_persistence())
            _logger.debug("token_cache_created", cache_path=str(self._cache_path))
        return self._cache

    def has_cached_tokens(self) -> bool:
        """Check if there are cached access tokens.

        Both the in-memory cache and the cache file are consulted.
        msal-extensions' ``PersistedTokenCache`` loads the file lazily (during
        token lookups and writes), not when it is constructed, so right after a
        restart its in-memory state is empty even if the file holds tokens.

        Returns:
            True if at least one access token is cached, False otherwise.
        """
        cache = self.get_cache()
        if self._has_access_tokens(cache.serialize()):
            return True
        return self._has_access_tokens(self._read_persisted_state())

    def _read_persisted_state(self) -> str | None:
        """Return the raw content of the cache file, or None if unreadable."""
        try:
            return self._get_persistence().load()
        except OSError as exc:
            # msal-extensions signals a missing file with PersistenceNotFound,
            # an IOError/OSError subclass; treat any read failure as "no tokens".
            _logger.debug(
                "token_cache_not_readable",
                cache_path=str(self._cache_path),
                error=str(exc),
            )
            return None

    @staticmethod
    def _has_access_tokens(data: str | None) -> bool:
        """Return True if serialized MSAL cache *data* holds any access token.

        The MSAL cache format keeps access tokens under the "AccessToken" key.
        Empty, missing, or malformed data counts as "no tokens".
        """
        if not data:
            return False
        try:
            cache_data = json.loads(data)
        except (json.JSONDecodeError, TypeError):
            return False
        if not isinstance(cache_data, dict):
            return False
        return bool(cache_data.get("AccessToken"))

    def clear_cache(self) -> None:
        """Clear all cached tokens, in memory and on disk.

        An empty cache is written to the cache file (the file itself remains)
        and the in-memory cache is reset to the same empty state.
        """
        cache = self.get_cache()
        # deserialize() only replaces the in-memory state; without also saving,
        # the tokens would stay in the file and be reloaded later (e.g. after a
        # restart). Save first so a failed write leaves memory matching disk.
        # NOTE(review): this write does not take msal-extensions' cross-process
        # cache lock; confirm no other process writes this file concurrently.
        self._get_persistence().save(self._EMPTY_CACHE_STATE)
        cache.deserialize(self._EMPTY_CACHE_STATE)
        _logger.info("token_cache_cleared", cache_path=str(self._cache_path))