"""
Thread Credentials Management
Persistent storage for user credentials (PEC, Microsoft, Infocert) by thread_id
Survives agent memory limits and container restarts
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
import logging
from .models import ThreadCredentials
from .database import get_db
logger = logging.getLogger(__name__)
def save_credentials(
    thread_id: str,
    credential_type: str,
    db: Optional[Session] = None,
    expires_in_days: Optional[int] = None,
    **kwargs
) -> ThreadCredentials:
    """
    Save or update credentials for a thread.

    Looks up the row matching (thread_id, credential_type).
    - If it exists: each keyword argument that names an existing attribute and
      is not None overwrites that attribute. None values and unknown keys are
      skipped, so callers can update a subset of fields.
    - If it does not exist: a new ThreadCredentials row is built from the
      keyword arguments and added to the session.

    Args:
        thread_id: Agent conversation thread ID
        credential_type: Type of credential ('pec', 'microsoft')
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)
        expires_in_days: Optional expiry in days. When truthy, expires_at is set
            to now (UTC) + this many days on BOTH insert and update, so
            re-saving refreshes the expiry. When None/0, a new row gets no
            expiry and an existing row keeps its current expires_at.
        **kwargs: Credential fields (pec_email, pec_password, user_email, etc.)

    Returns:
        The committed and refreshed ThreadCredentials object.

    Raises:
        SQLAlchemyError: re-raised after the session has been rolled back.

    Example:
        save_credentials(
            thread_id="thread_abc123",
            credential_type="pec",
            pec_email="test@legalmail.it",
            pec_password="secret123"
        )
    """
    db_provided = db is not None
    if not db_provided:
        db = next(get_db())
    try:
        # Check if credentials already exist
        cred = db.query(ThreadCredentials).filter_by(
            thread_id=thread_id,
            credential_type=credential_type
        ).first()
        # One timestamp for every field written in this call.
        now = datetime.utcnow()
        if cred:
            # UPDATE existing credentials (skip None so partial updates don't clear fields)
            for key, value in kwargs.items():
                if hasattr(cred, key) and value is not None:
                    setattr(cred, key, value)
            cred.updated_at = now
            cred.last_used_at = now
            logger.info(f"✅ Updated {credential_type} credentials for thread {thread_id[:20]}...")
        else:
            # INSERT new credentials
            cred = ThreadCredentials(
                thread_id=thread_id,
                credential_type=credential_type,
                **kwargs
            )
            db.add(cred)
            logger.info(f"πΎ Saved NEW {credential_type} credentials for thread {thread_id[:20]}...")
        # Apply expiry on both paths. Previously the update path ignored it,
        # so a row whose expires_at had passed stayed expired after re-saving.
        if expires_in_days:
            cred.expires_at = now + timedelta(days=expires_in_days)
        db.commit()
        db.refresh(cred)
        return cred
    except SQLAlchemyError as e:
        logger.error(f"β Failed to save credentials for thread {thread_id}: {e}")
        db.rollback()
        raise
    finally:
        if not db_provided:
            db.close()
def get_credentials(
    thread_id: str,
    credential_type: str,
    db: Optional[Session] = None
) -> Optional[ThreadCredentials]:
    """
    Retrieve non-expired credentials for a thread.

    Args:
        thread_id: Agent conversation thread ID
        credential_type: Type of credential ('pec', 'microsoft')
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)

    Returns:
        The ThreadCredentials object. Returns None if no row exists, if the
        row's expires_at is earlier than now (UTC), or if a database error
        occurred. Expired rows are not deleted here; cleanup_expired_credentials
        does that.

    Example:
        creds = get_credentials("thread_abc123", "pec")
        if creds:
            print(creds.pec_email)
    """
    db_provided = db is not None
    if not db_provided:
        db = next(get_db())
    try:
        cred = db.query(ThreadCredentials).filter_by(
            thread_id=thread_id,
            credential_type=credential_type
        ).first()
        if not cred:
            logger.debug(f"β No {credential_type} credentials found for thread {thread_id[:20]}...")
            return None
        # NOTE(review): compared against naive utcnow(), so expires_at is
        # presumably stored as naive UTC — confirm against the model.
        if cred.expires_at and cred.expires_at < datetime.utcnow():
            logger.warning(f"β οΈ Credentials expired for thread {thread_id[:20]}... type={credential_type}")
            return None
        logger.info(f"✅ Found {credential_type} credentials for thread {thread_id[:20]}...")
        return cred
    except SQLAlchemyError as e:
        logger.error(f"β Failed to get credentials for thread {thread_id}: {e}")
        # Roll back so a caller-provided session is not left in a failed
        # transaction state after this best-effort lookup.
        db.rollback()
        return None
    finally:
        if not db_provided:
            db.close()
def update_last_used(
    thread_id: str,
    credential_type: str,
    db: Optional[Session] = None
) -> bool:
    """
    Stamp last_used_at with the current UTC time for a thread's credentials.

    Called after successful tool execution.

    Args:
        thread_id: Agent conversation thread ID
        credential_type: Type of credential
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)

    Returns:
        True if a matching row was found and the change committed. False if no
        row matched, or if a database error occurred (the session is rolled
        back in that case).
    """
    owns_session = db is None
    if owns_session:
        db = next(get_db())
    try:
        record = (
            db.query(ThreadCredentials)
            .filter_by(thread_id=thread_id, credential_type=credential_type)
            .first()
        )
        if not record:
            return False
        record.last_used_at = datetime.utcnow()
        db.commit()
        return True
    except SQLAlchemyError as e:
        logger.error(f"β Failed to update last_used for thread {thread_id}: {e}")
        db.rollback()
        return False
    finally:
        if owns_session:
            db.close()
def delete_credentials(
    thread_id: str,
    credential_type: Optional[str] = None,
    db: Optional[Session] = None
) -> int:
    """
    Delete credentials for a thread.

    Args:
        thread_id: Agent conversation thread ID
        credential_type: Type to delete. Only an explicit None deletes ALL
            credential types for the thread. Any other value (including "")
            restricts the delete to that exact type.
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)

    Returns:
        Number of rows deleted. Returns 0 on a database error, after rolling
        back the session.
    """
    db_provided = db is not None
    if not db_provided:
        db = next(get_db())
    try:
        query = db.query(ThreadCredentials).filter_by(thread_id=thread_id)
        # Test for None explicitly. A truthiness check would let an empty
        # string silently widen this into "delete every credential type".
        if credential_type is not None:
            query = query.filter_by(credential_type=credential_type)
        count = query.delete()
        db.commit()
        logger.info(f"ποΈ Deleted {count} credentials for thread {thread_id[:20]}...")
        return count
    except SQLAlchemyError as e:
        logger.error(f"β Failed to delete credentials for thread {thread_id}: {e}")
        db.rollback()
        return 0
    finally:
        if not db_provided:
            db.close()
def cleanup_expired_credentials(db: Optional[Session] = None) -> int:
    """
    Remove every credential row whose expires_at is earlier than now (UTC).

    Rows without an expires_at should be kept: in SQL, `NULL < x` is not true.
    Meant to run periodically (e.g. as a daily cron job).

    Args:
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)

    Returns:
        Number of rows deleted. Returns 0 on a database error, after rolling
        back the session.
    """
    owns_session = db is None
    if owns_session:
        db = next(get_db())
    try:
        cutoff = datetime.utcnow()
        expired_rows = db.query(ThreadCredentials).filter(
            ThreadCredentials.expires_at < cutoff
        )
        removed = expired_rows.delete()
        db.commit()
        if removed > 0:
            logger.info(f"ποΈ Cleaned up {removed} expired credentials")
        return removed
    except SQLAlchemyError as e:
        logger.error(f"β Failed to cleanup expired credentials: {e}")
        db.rollback()
        return 0
    finally:
        if owns_session:
            db.close()
def get_all_credentials_for_thread(
    thread_id: str,
    db: Optional[Session] = None
) -> Dict[str, ThreadCredentials]:
    """
    Return every credential row stored for a thread, keyed by credential_type.

    Expiry is not checked, so expired rows are included; intended for debugging.

    Args:
        thread_id: Agent conversation thread ID
        db: Database session (optional; a temporary one is opened and closed
            here when omitted)

    Returns:
        Dictionary mapping credential_type -> ThreadCredentials. The dictionary
        is empty when nothing is found or a database error occurs.
    """
    owns_session = db is None
    if owns_session:
        db = next(get_db())
    try:
        rows = db.query(ThreadCredentials).filter_by(thread_id=thread_id).all()
        by_type = {}
        for row in rows:
            by_type[row.credential_type] = row
        return by_type
    except SQLAlchemyError as e:
        logger.error(f"β Failed to get all credentials for thread {thread_id}: {e}")
        return {}
    finally:
        if owns_session:
            db.close()
# Convenience functions for specific credential types
def save_pec_credentials(thread_id: str, pec_email: str, pec_password: str, **kwargs) -> ThreadCredentials:
    """
    Store (insert or update) the PEC email/password for a thread.

    Thin wrapper over save_credentials with credential_type fixed to "pec".
    Extra keyword arguments (e.g. db, expires_in_days, other model fields) are
    forwarded unchanged.
    """
    pec_fields = {"pec_email": pec_email, "pec_password": pec_password}
    return save_credentials(thread_id, "pec", **pec_fields, **kwargs)
def get_pec_credentials(thread_id: str, db: Optional[Session] = None) -> Optional[ThreadCredentials]:
    """
    Get non-expired PEC credentials for a thread.

    Args:
        thread_id: Agent conversation thread ID
        db: Optional existing session to reuse (as with get_credentials); a
            temporary session is used when omitted.

    Returns:
        ThreadCredentials object or None (see get_credentials).
    """
    return get_credentials(thread_id, "pec", db=db)
def save_microsoft_credentials(thread_id: str, user_email: str, **kwargs) -> ThreadCredentials:
    """
    Store (insert or update) Microsoft credentials for a thread.

    Delegates to save_credentials with credential_type="microsoft". user_email
    and any extra keyword arguments (e.g. db, expires_in_days, other model
    fields) are forwarded unchanged.
    """
    return save_credentials(thread_id, "microsoft", user_email=user_email, **kwargs)
def get_microsoft_credentials(thread_id: str, db: Optional[Session] = None) -> Optional[ThreadCredentials]:
    """
    Get non-expired Microsoft credentials for a thread.

    Args:
        thread_id: Agent conversation thread ID
        db: Optional existing session to reuse (as with get_credentials); a
            temporary session is used when omitted.

    Returns:
        ThreadCredentials object or None (see get_credentials).
    """
    return get_credentials(thread_id, "microsoft", db=db)