"""AWS authentication helper with IRSA support and credentials caching.
IAM Roles for Service Accounts (IRSA) enables AWS authentication using temporary
credentials obtained via STS AssumeRole instead of static access keys.
Usage:
from lib.aws.aws_auth_helper import create_boto3_session, get_boto3_credentials
from lib.aws.aws_profile_config import get_aws_profile_config
from lib.enums import AwsProfilesList
# Get boto3 session (for AWS SDK clients)
profile = get_aws_profile_config(AwsProfilesList.ATHENA_ANALYTICS_INTERNAL)
session = create_boto3_session(profile)
athena_client = session.client('athena')
# Get raw credentials (for PyArrow, etc)
creds = get_boto3_credentials(profile)
filesystem = pyarrow.fs.S3FileSystem(
access_key=creds.access_key,
secret_key=creds.secret_key,
session_token=creds.session_token,
region=profile.region,
)
For IRSA configuration and architecture details, see IRSA.md
"""
import threading
import time
from dataclasses import dataclass
from typing import Any
import boto3
import settings as settings
from botocore.client import BaseClient
from botocore.exceptions import ClientError
from lib import loggerutils
from lib.aws.aws_profile_config import AWSProfileConfig
# Module-level logger obtained from the project's loggerutils under the name 'analytics'.
logger = loggerutils.getLogger('analytics')
@dataclass
class Boto3Credentials:
"""AWS credentials for libraries that need raw credentials (PyArrow, etc)."""
access_key: str
secret_key: str
session_token: str | None = None
@dataclass
class CachedCredentials:
    """Credential set held in the cache together with its expiration time."""

    access_key_id: str
    secret_access_key: str
    session_token: str
    # Expiration instant on the same clock as time.time() (see is_expired).
    expiration_ts: float

    def is_expired(self, buffer_sec: int = settings.AWS_CREDENTIAL_REFRESH_BUFFER_SEC) -> bool:
        """Report whether the credentials are expired or about to expire.

        Args:
            buffer_sec: Safety margin, in seconds, subtracted from the real
                expiration time. The default is read from settings once,
                when the class is defined.

        Returns:
            True when the current time has reached ``expiration_ts - buffer_sec``.
        """
        refresh_deadline = self.expiration_ts - buffer_sec
        return time.time() >= refresh_deadline
class LRUCredentialCache:
    """Bounded credential store with least-recently-used eviction.

    Every operation runs under a re-entrant lock. Recency is tracked through
    dict insertion order: the first key is the least recently used one, and an
    entry is re-inserted at the end whenever it is read or overwritten.
    """

    def __init__(self, max_size: int = 20):
        self._max_size = max_size
        self._lock = threading.RLock()
        self._cache: dict[str, CachedCredentials] = {}

    def get(self, key: str) -> CachedCredentials | None:
        """Return the cached credentials for ``key``, or None if absent or expired.

        An expired entry is removed from the cache as a side effect.
        """
        with self._lock:
            try:
                entry = self._cache[key]
            except KeyError:
                return None

            if not entry.is_expired():
                # Re-insert so this key becomes the most recently used one.
                del self._cache[key]
                self._cache[key] = entry
                return entry

            del self._cache[key]
            logger.info('Evicted expired credentials', extra_data={'profile': key})
            return None

    def set(self, key: str, value: CachedCredentials) -> None:
        """Store ``value`` under ``key``, evicting the oldest entry when full.

        Overwriting an existing key never evicts another entry; it only moves
        the key to the most-recently-used position.
        """
        with self._lock:
            store = self._cache
            if key in store:
                store.pop(key)
            elif len(store) >= self._max_size:
                # The first key in insertion order is the least recently used.
                oldest_key = next(iter(store))
                store.pop(oldest_key)
                logger.info('Evicted LRU credentials', extra_data={'profile': oldest_key})
            store[key] = value

    def invalidate(self, key: str) -> None:
        """Remove ``key`` from the cache; a missing key is silently ignored."""
        with self._lock:
            self._cache.pop(key, None)
# Process-wide credential cache, keyed by profile name; uses the default max_size (20 entries).
_CREDENTIALS_CACHE = LRUCredentialCache()
def invalidate_credentials_cache(profile_name: str) -> None:
    """Drop any cached credentials stored for ``profile_name``.

    Intended for callers that receive ExpiredTokenException from AWS: after
    invalidation, the next create_boto3_session call for this profile finds no
    cache entry and fetches fresh credentials. Unknown profile names are a no-op.

    Args:
        profile_name: AWS profile name whose cache entry should be removed
    """
    _CREDENTIALS_CACHE.invalidate(key=profile_name)
# STS error codes that describe a permanent problem with the request or the
# caller's permissions. Retrying them cannot succeed, so they fail fast.
# Any code not listed here is still retried.
_NON_RETRYABLE_STS_ERROR_CODES = frozenset(
    {
        'AccessDenied',
        'AccessDeniedException',
        'MalformedPolicyDocument',
        'PackedPolicyTooLarge',
        'RegionDisabledException',
        'ValidationError',
    }
)


def _assume_role_with_retry(
    sts_client: BaseClient,
    role_arn: str,
    role_session_name: str,
    max_retries: int = 5,
) -> dict[str, Any]:
    """Call STS AssumeRole, retrying failures with exponential backoff.

    A ClientError whose code is in _NON_RETRYABLE_STS_ERROR_CODES is logged
    and re-raised immediately. Any other ClientError is retried after
    sleeping ``2**attempt`` seconds (1, 2, 4, ...), up to ``max_retries``
    attempts in total. Exceptions other than ClientError propagate untouched.

    Args:
        sts_client: STS client used to issue the assume_role call
        role_arn: ARN of the role to assume
        role_session_name: Session name passed as RoleSessionName
        max_retries: Maximum number of attempts (not additional retries)

    Returns:
        The raw assume_role response

    Raises:
        ClientError: on a non-retryable error, or when the last attempt fails
        RuntimeError: if no attempt was made because max_retries is < 1
    """
    for attempt in range(max_retries):
        try:
            return sts_client.assume_role(
                RoleArn=role_arn,
                RoleSessionName=role_session_name,
            )
        except ClientError as e:
            # Read defensively: a malformed error response must not mask the
            # original ClientError with a KeyError.
            error_code = e.response.get('Error', {}).get('Code', 'Unknown')

            if error_code in _NON_RETRYABLE_STS_ERROR_CODES:
                logger.error(
                    'Failed',
                    data='STS assume role failed with non-retryable error',
                    extra_data={
                        'role_arn': role_arn,
                        'attempts': attempt + 1,
                        'error_code': error_code,
                    },
                    exc_info=True,
                )
                raise

            if attempt >= max_retries - 1:
                logger.error(
                    'Failed',
                    data='STS assume role failed after retries',
                    extra_data={
                        'role_arn': role_arn,
                        'attempts': max_retries,
                        'error_code': error_code,
                    },
                    exc_info=True,
                )
                raise

            delay = 2**attempt
            logger.warning(
                'Retry',
                data='STS assume role failed, retrying',
                extra_data={
                    'role_arn': role_arn,
                    'error_code': error_code,
                    'attempt': attempt + 1,
                    'delay_sec': delay,
                },
            )
            time.sleep(delay)

    # Only reachable when the loop body never ran (max_retries < 1).
    raise RuntimeError(f'STS assume role failed unexpectedly for {role_arn}')
def _build_session(
    access_key_id: str | None,
    secret_access_key: str | None,
    session_token: str | None,
    region: str | None,
) -> boto3.Session:
    """Build a boto3 Session from explicit credential values and a region."""
    return boto3.Session(
        aws_access_key_id=access_key_id,
        aws_secret_access_key=secret_access_key,
        aws_session_token=session_token,
        region_name=region,
    )


def _log_initial_identity(sts_client: BaseClient, profile_name: str) -> None:
    """Log the caller identity behind ``sts_client``; log and re-raise ClientError."""
    try:
        initial_identity = sts_client.get_caller_identity()
    except ClientError as e:
        logger.error(
            'Failed',
            data='Cannot verify initial AWS identity - IRSA configuration issue?',
            extra_data={
                'profile': profile_name,
                'error_code': e.response['Error']['Code'],
                'message': e.response['Error']['Message'],
            },
            exc_info=True,
        )
        raise
    else:
        logger.info(
            'Initial AWS identity verified',
            extra_data={
                'profile': profile_name,
                'account': initial_identity['Account'],
                'user_id': initial_identity['UserId'],
                'arn': initial_identity['Arn'],
            },
        )


def create_boto3_session(aws_config: AWSProfileConfig) -> boto3.Session:
    """Create boto3 session with either IRSA or static credentials.

    Static profiles (``uses_irsa`` is false) get a session built directly from
    the keys stored in the profile. For IRSA profiles the function:

    1. returns a session built from cached credentials when the cache holds a
       non-expired entry for the profile name;
    2. otherwise verifies the ambient identity with get_caller_identity,
       assumes the configured role (with retries), stores the resulting
       credentials in the cache and returns a session built from them.

    Args:
        aws_config: AWS profile configuration

    Returns:
        Configured boto3 Session

    Raises:
        ClientError: if the identity check or the assume-role call fails
        AssertionError: if ``aws_config.creds`` is not of the type implied by
            ``aws_config.uses_irsa``
    """
    profile_name = aws_config.name.value

    if not aws_config.uses_irsa:
        logger.info('Using static AWS credentials', extra_data={'profile': profile_name})
        from lib.aws.aws_profile_config import DirectKeyCredentials

        assert isinstance(aws_config.creds, DirectKeyCredentials)
        return _build_session(
            aws_config.creds.aws_access_key_id,
            aws_config.creds.aws_secret_access_key,
            aws_config.creds.aws_session_token,
            aws_config.region,
        )

    cached = _CREDENTIALS_CACHE.get(profile_name)
    if cached:
        logger.info('Using cached AWS credentials', extra_data={'profile': profile_name})
        return _build_session(
            cached.access_key_id,
            cached.secret_access_key,
            cached.session_token,
            aws_config.region,
        )

    logger.info('Refreshing AWS credentials via assume role', extra_data={'profile': profile_name})
    # No explicit keys: boto3 resolves the ambient credentials itself.
    sts_client = boto3.Session(region_name=aws_config.region).client('sts')
    _log_initial_identity(sts_client, profile_name)

    from lib.aws.aws_profile_config import IrsaCredentials

    assert isinstance(aws_config.creds, IrsaCredentials)

    # Monotonic clock: the measured duration must not be affected by
    # wall-clock adjustments made while the call is in flight.
    started = time.monotonic()
    assumed_role = _assume_role_with_retry(
        sts_client=sts_client,
        role_arn=aws_config.creds.assume_role_arn,
        role_session_name=aws_config.creds.assume_role_name,
    )
    duration_ms = int((time.monotonic() - started) * 1000)

    credentials = assumed_role['Credentials']
    _CREDENTIALS_CACHE.set(
        profile_name,
        CachedCredentials(
            access_key_id=credentials['AccessKeyId'],
            secret_access_key=credentials['SecretAccessKey'],
            session_token=credentials['SessionToken'],
            expiration_ts=credentials['Expiration'].timestamp(),
        ),
    )
    logger.info(
        'AWS credentials refreshed',
        extra_data={'profile': profile_name},
        number=duration_ms,
    )
    return _build_session(
        credentials['AccessKeyId'],
        credentials['SecretAccessKey'],
        credentials['SessionToken'],
        aws_config.region,
    )
def get_boto3_credentials(aws_config: AWSProfileConfig) -> Boto3Credentials:
    """Get AWS credentials for libraries that need raw credentials.

    Used for PyArrow and other libraries that don't accept boto3 Session.
    The session is obtained through create_boto3_session, so IRSA profiles
    go through the same credential cache.

    Args:
        aws_config: AWS profile configuration

    Returns:
        Boto3Credentials with access_key, secret_key, session_token

    Raises:
        RuntimeError: if the session resolves no credentials at all
    """
    session = create_boto3_session(aws_config)
    creds = session.get_credentials()
    if creds is None:
        # Without this guard the failure is an opaque AttributeError on None.
        raise RuntimeError(f'No AWS credentials available for profile {aws_config.name.value}')

    # Take one consistent snapshot: reading the three attributes separately
    # from a refreshable credentials object could mix two credential sets.
    frozen = creds.get_frozen_credentials()
    return Boto3Credentials(
        access_key=frozen.access_key,
        secret_key=frozen.secret_key,
        session_token=frozen.token,
    )