# vault_client.py
"""HashiCorp Vault client: AppRole login and retrieval of JIT NetBox API tokens from KV v2."""
import logging
import time
from typing import Optional
import hvac
from hvac.exceptions import VaultError
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class VaultClient:
    """Client for HashiCorp Vault AppRole login and NetBox token retrieval.

    The client logs in with AppRole credentials, caches the resulting Vault
    token together with a locally computed expiry deadline, and reads a
    NetBox API token from a KV v2 secret.
    """

    # Seconds subtracted from the Vault token TTL so the cached token is
    # refreshed shortly before Vault itself would expire it.
    _EXPIRY_BUFFER_SECONDS = 60

    def __init__(
        self,
        vault_addr: str,
        role_id: str,
        secret_id: str,
        mount_path: str = "auth/approle",
    ):
        """Initialize Vault client.

        This does not log in; call ``authenticate`` (``mint_netbox_token``
        also authenticates on demand).

        Args:
            vault_addr: Vault server address
            role_id: AppRole Role ID
            secret_id: AppRole Secret ID
            mount_path: Auth mount path (default: auth/approle)
        """
        self.vault_addr = vault_addr
        self.role_id = role_id
        self.secret_id = secret_id
        self.mount_path = mount_path
        self.client = hvac.Client(url=vault_addr)
        self._token: Optional[str] = None
        # Deadline on the time.monotonic() clock after which the cached Vault
        # token is treated as expired (None until the first successful login).
        self._token_expiry: Optional[float] = None

    def _auth_mount_point(self) -> str:
        """Return ``mount_path`` in the form hvac's ``mount_point`` expects.

        hvac builds ``/v1/auth/{mount_point}/login`` itself, so surrounding
        slashes and a leading ``auth/`` segment are removed, e.g.
        ``"auth/approle"`` -> ``"approle"`` and ``"auth/team/approle/"`` ->
        ``"team/approle"``.
        """
        path = self.mount_path.strip("/")
        prefix = "auth/"
        if path.startswith(prefix):
            path = path[len(prefix):]
        return path

    def _store_token(self, token: str, ttl: float) -> None:
        """Cache *token*, install it on the hvac client and set its expiry."""
        self._token = token
        self.client.token = token
        if ttl <= 0:
            # A non-positive TTL leaves nothing to count down from; treat the
            # token as non-expiring rather than as already expired.
            self._token_expiry = float("inf")
        else:
            # Refresh early, but never let the buffer swallow the whole TTL:
            # with a TTL shorter than the buffer the token would otherwise be
            # "expired" on arrival and force a re-login on every call.
            buffer = min(self._EXPIRY_BUFFER_SECONDS, ttl / 2)
            self._token_expiry = time.monotonic() + ttl - buffer

    def authenticate(self, max_retries: int = 3, backoff_base: float = 1.0) -> bool:
        """Authenticate with Vault using AppRole.

        Vault errors and transport failures (``OSError``; ``requests``
        exceptions derive from it) are retried with exponential backoff.
        Any other exception aborts immediately.

        Args:
            max_retries: Maximum number of login attempts
            backoff_base: Base delay for exponential backoff in seconds

        Returns:
            True if authentication successful, False otherwise
        """
        mount_point = self._auth_mount_point()
        for attempt in range(max_retries):
            try:
                response = self.client.auth.approle.login(
                    role_id=self.role_id,
                    secret_id=self.secret_id,
                    mount_point=mount_point,
                )
                auth = response["auth"]
                token = auth["client_token"]
                # Token TTL in seconds; 3600 if Vault omitted it.
                ttl = auth.get("lease_duration", 3600)
            except (VaultError, OSError) as e:
                logger.warning("Vault authentication attempt %d failed: %s", attempt + 1, e)
                if attempt < max_retries - 1:
                    delay = backoff_base * (2 ** attempt)
                    logger.info("Retrying in %s seconds...", delay)
                    time.sleep(delay)
            except Exception as e:
                logger.exception("Unexpected error during Vault authentication: %s", e)
                return False
            else:
                self._store_token(token, ttl)
                logger.info("Vault authentication successful")
                return True

        logger.error("Vault authentication failed after all retries")
        return False

    @staticmethod
    def _split_kv_path(vault_path: str) -> tuple:
        """Split ``vault_path`` into ``(mount_point, secret_path)``.

        ``"netbox/jit-tokens"`` -> ``("netbox", "jit-tokens")``. Surrounding
        slashes are ignored; a path with no ``/`` uses the ``secret`` mount.
        """
        normalized = vault_path.strip("/")
        mount_point, sep, secret_path = normalized.partition("/")
        if not sep:
            return "secret", normalized
        return mount_point, secret_path

    def mint_netbox_token(
        self,
        vault_path: str,
        max_retries: int = 3,
        backoff_base: float = 1.0,
    ) -> Optional[str]:
        """Fetch the NetBox API token stored in a Vault KV v2 secret.

        Despite the name, this reads the ``token`` field of an existing
        secret via ``read_secret_version``; it does not generate anything.
        The Vault token is (re)acquired first if it is missing or expired.

        Args:
            vault_path: ``<kv-mount>/<secret-path>`` (e.g. ``netbox/jit-tokens``).
                A path without ``/`` is read from the ``secret`` mount.
            max_retries: Maximum number of read attempts
            backoff_base: Delay unit in seconds for the linear backoff
                between attempts (attempt *n* waits ``backoff_base * n``)

        Returns:
            NetBox API token if successful, None otherwise
        """
        if not self.is_authenticated():
            logger.info("Vault token expired or missing, re-authenticating...")
            if not self.authenticate():
                return None

        # Path parsing does not depend on the attempt, so do it once.
        mount_point, secret_path = self._split_kv_path(vault_path)
        # Re-login at most once per call: if a fresh token is still denied the
        # problem is policy, and further logins would only add latency.
        reauthenticated = False

        for attempt in range(max_retries):
            try:
                response = self.client.secrets.kv.v2.read_secret_version(
                    path=secret_path,
                    mount_point=mount_point,
                )
                netbox_token = response["data"]["data"].get("token")
            except (VaultError, OSError) as e:
                logger.warning(
                    "Failed to retrieve NetBox token (attempt %d/%d): %s",
                    attempt + 1,
                    max_retries,
                    e,
                )
                message = str(e).lower()
                # Heuristic: these messages suggest the Vault token itself is
                # no longer accepted, so a fresh login may fix the read.
                if (
                    not reauthenticated
                    and isinstance(e, VaultError)
                    and ("permission denied" in message or "invalid" in message)
                ):
                    reauthenticated = True
                    if self.authenticate():
                        continue
                if attempt < max_retries - 1:
                    time.sleep(backoff_base * (attempt + 1))
            except Exception as e:
                logger.exception("Unexpected error retrieving NetBox token: %s", e)
                return None
            else:
                if netbox_token:
                    logger.debug("Successfully retrieved NetBox token from Vault")
                    return netbox_token
                logger.warning("NetBox token not found in Vault response")
                return None

        logger.error("Failed to retrieve NetBox token after all retries")
        return None

    def is_authenticated(self) -> bool:
        """Check if client is authenticated and token is valid.

        Returns:
            True if a Vault token is cached and its local expiry deadline
            (on the ``time.monotonic`` clock) has not passed, False otherwise
        """
        return (
            self._token is not None
            and self._token_expiry is not None
            and time.monotonic() < self._token_expiry
        )