#!/usr/bin/env python3
"""
Servicio de autenticación OAuth2 y JWT
"""
import os
import secrets
import json
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, Any, List
from urllib.parse import urlencode, parse_qs, urlparse

import requests
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
from authlib.integrations.requests_client import OAuth2Session

try:
    from .auth_models import (
        User, AccessToken, RefreshToken, OAuthState, OAuthProvider, UserRole,
        OAuthProviderConfig, OAUTH_PROVIDERS_CONFIG,
        AuthError, TokenExpiredError, InvalidTokenError, InsufficientScopeError
    )
    from .access_control import access_control
    from .logger import get_logger
except ImportError:
    from auth_models import (
        User, AccessToken, RefreshToken, OAuthState, OAuthProvider, UserRole,
        OAuthProviderConfig, OAUTH_PROVIDERS_CONFIG,
        AuthError, TokenExpiredError, InvalidTokenError, InsufficientScopeError
    )
    from access_control import access_control
    from logger import get_logger
# Module-level logger used by every AuthService method in this file.
logger = get_logger("auth-service")
class AuthService:
    """OAuth2 login service that issues JWT access tokens and opaque refresh tokens.

    Users, tokens and pending OAuth states are kept in in-memory dictionaries.
    They are lost on restart and are not shared between processes; a database
    should back them in production.
    """

    # Public base URL used for the OAuth callback when running on Railway
    # (detected by the presence of the PORT environment variable).
    RAILWAY_BASE_URL = "https://mcptavilidateoaut-production.up.railway.app"
    # Base URL used for the OAuth callback in local development (PORT unset).
    LOCAL_BASE_URL = "http://localhost:8001"
    # Seconds to wait on the provider's token/userinfo endpoints. Without a
    # timeout, ``requests`` can block the caller indefinitely.
    HTTP_TIMEOUT_SECONDS = 10

    def __init__(self):
        # JWT configuration. The fallback key is only generated when
        # JWT_SECRET_KEY is missing or empty. ``os.getenv(key, default)``
        # would evaluate the default eagerly, so the "auto-generated key"
        # warning was logged even when a real key was configured.
        self.secret_key = os.getenv("JWT_SECRET_KEY") or self._generate_secret_key()
        self.algorithm = "HS256"
        self.access_token_expire_minutes = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "60"))
        self.refresh_token_expire_days = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "30"))

        # In-memory storage (use a database in production).
        self.users: Dict[str, User] = {}                    # keyed by User.id
        self.access_tokens: Dict[str, AccessToken] = {}     # keyed by encoded JWT
        self.refresh_tokens: Dict[str, RefreshToken] = {}   # keyed by opaque token
        self.oauth_states: Dict[str, OAuthState] = {}       # keyed by state value

        # OAuth2 providers that have credentials in the environment.
        self.providers: Dict[OAuthProvider, OAuthProviderConfig] = {}
        self._load_oauth_providers()

        logger.info("Servicio de autenticación inicializado")
        logger.info(f"Proveedores OAuth2 configurados: {list(self.providers.keys())}")

    def _generate_secret_key(self) -> str:
        """Return a random URL-safe signing key and warn that it is ephemeral.

        A new key is generated for each instance, so tokens signed with it
        do not survive a restart.
        """
        key = secrets.token_urlsafe(32)
        logger.warning("Usando clave JWT generada automáticamente. En producción, configura JWT_SECRET_KEY")
        return key

    def _load_oauth_providers(self):
        """Register every OAuthProvider that has both credentials in the environment.

        The variables read are ``<PROVIDER>_CLIENT_ID`` and
        ``<PROVIDER>_CLIENT_SECRET``. The remaining settings (URLs, scopes,
        ...) come from ``OAUTH_PROVIDERS_CONFIG``.
        """
        for provider in OAuthProvider:
            env_prefix = provider.value.upper()
            client_id = os.getenv(f"{env_prefix}_CLIENT_ID")
            client_secret = os.getenv(f"{env_prefix}_CLIENT_SECRET")

            # Log only whether each variable is present, never its value.
            logger.info(f"Buscando {env_prefix}_CLIENT_ID: {'✓' if client_id else '✗'}")
            logger.info(f"Buscando {env_prefix}_CLIENT_SECRET: {'✓' if client_secret else '✗'}")

            if client_id and client_secret:
                config = OAUTH_PROVIDERS_CONFIG[provider]
                self.providers[provider] = OAuthProviderConfig(
                    client_id=client_id,
                    client_secret=client_secret,
                    **config
                )
                logger.info(f"Proveedor {provider.value} configurado")
            else:
                logger.warning(f"Proveedor {provider.value} no configurado (falta CLIENT_ID o CLIENT_SECRET)")

    def get_available_providers(self) -> List[OAuthProvider]:
        """Return the providers that were configured at start-up."""
        return list(self.providers.keys())

    def _get_callback_uri(self, provider: OAuthProvider) -> str:
        """Return this server's OAuth callback URL for ``provider``.

        The authorization request and the code-for-token exchange must use
        the same redirect URI, so both go through this helper.

        The base URL is taken from ``OAUTH_CALLBACK_BASE_URL`` when set.
        Otherwise the Railway URL is used when PORT is set, and the local
        development URL when it is not.
        """
        base_url = os.getenv("OAUTH_CALLBACK_BASE_URL")
        if base_url:
            base_url = base_url.rstrip("/")
        elif os.getenv("PORT"):
            base_url = self.RAILWAY_BASE_URL
        else:
            base_url = self.LOCAL_BASE_URL
        return f"{base_url}/auth/callback/{provider.value}"

    def create_oauth_authorization_url(
        self,
        provider: OAuthProvider,
        redirect_uri: str,
        state: Optional[str] = None
    ) -> tuple[str, str]:
        """Build the provider authorization URL for an OAuth2 code flow.

        Args:
            provider: Provider to authenticate with; must be configured.
            redirect_uri: Passed to ``OAuthState.create_new``. It is not sent
                to the provider; the provider always redirects back to this
                server's callback endpoint.
            state: Currently ignored; a fresh random state is always created.

        Returns:
            ``(authorization_url, state)``.

        Raises:
            AuthError: If ``provider`` is not configured.
        """
        if provider not in self.providers:
            raise AuthError(f"Proveedor {provider.value} no configurado")
        config = self.providers[provider]

        # Anti-CSRF state, valid for 30 minutes.
        oauth_state = OAuthState.create_new(provider, redirect_uri, expires_in_minutes=30)
        self.oauth_states[oauth_state.state] = oauth_state
        logger.info(f"Estado OAuth creado: {oauth_state.state[:8]}... expira en {oauth_state.expires_at}")

        self._cleanup_expired_states()

        callback_uri = self._get_callback_uri(provider)
        logger.info(f"OAuth {provider.value} - callback_uri: {callback_uri}")

        params = {
            "response_type": "code",
            "client_id": config.client_id,
            "redirect_uri": callback_uri,
            "scope": " ".join(config.scopes),
            "state": oauth_state.state
        }
        auth_url = f"{config.authorize_url}?{urlencode(params)}"
        logger.info(f"URL de autorización creada para {provider.value}")
        return auth_url, oauth_state.state

    def handle_oauth_callback(
        self,
        provider: OAuthProvider,
        code: str,
        state: str,
        redirect_uri: Optional[str]
    ) -> tuple[str, str]:
        """Validate an OAuth2 callback, log the user in and issue tokens.

        The state is removed as soon as it is looked up, so it is single-use
        even when a later step fails. This prevents replaying a state after a
        failed exchange, a provider mismatch or an access-control rejection.

        Returns:
            ``(access_token, refresh_token)``.

        Raises:
            AuthError: On an unknown, expired or mismatched state, on provider
                errors, or when the user is denied access.
        """
        # State values are anti-CSRF secrets: log how many are pending, not which.
        logger.info(f"Estados OAuth pendientes: {len(self.oauth_states)}")

        oauth_state = self.oauth_states.pop(state, None)
        if oauth_state is None:
            logger.error(f"Estado OAuth no encontrado: {(state or '')[:8]}...")
            raise AuthError("Estado OAuth2 inválido o expirado")
        if oauth_state.is_expired():
            raise AuthError("Estado OAuth2 expirado")
        if oauth_state.provider != provider:
            raise AuthError("Proveedor OAuth2 no coincide")

        config = self.providers[provider]
        user_info = self._exchange_code_for_user_info(config, code, redirect_uri, provider)
        user = self._create_or_update_user(provider, user_info)

        access_token = self._create_access_token(user.id)
        refresh_token = self._create_refresh_token(user.id)

        logger.info(f"Usuario autenticado: {user.email} via {provider.value}")
        return access_token, refresh_token

    def _fetch_json(self, method: str, url: str, *, log_prefix: str, error_message: str, **kwargs) -> Any:
        """Send an HTTP request with a timeout and return the decoded JSON body.

        Network failures, non-200 responses and non-JSON bodies are logged as
        ``"<log_prefix>: <detail>"`` and raised as ``AuthError(error_message)``.
        Extra keyword arguments are forwarded to ``requests.request``.
        """
        try:
            response = requests.request(method, url, timeout=self.HTTP_TIMEOUT_SECONDS, **kwargs)
        except requests.RequestException as exc:
            logger.error(f"{log_prefix}: {exc}")
            raise AuthError(error_message) from exc
        if response.status_code != 200:
            logger.error(f"{log_prefix}: {response.text}")
            raise AuthError(error_message)
        try:
            return response.json()
        except ValueError as exc:
            logger.error(f"{log_prefix}: {response.text}")
            raise AuthError(error_message) from exc

    def _exchange_code_for_user_info(
        self,
        config: OAuthProviderConfig,
        code: str,
        redirect_uri: Optional[str],
        provider: OAuthProvider
    ) -> Dict[str, Any]:
        """Exchange an authorization code for the provider's userinfo document.

        The code is first traded for an access token, which is then used to
        fetch the userinfo document.

        ``redirect_uri`` is accepted for interface compatibility but unused.
        The token request must repeat the callback URI that was sent in the
        authorization request, which ``_get_callback_uri`` rebuilds.

        Raises:
            AuthError: On network errors, non-200 responses, non-JSON bodies,
                or a token response without ``access_token``.
        """
        callback_uri = self._get_callback_uri(provider)
        logger.info(f"Token exchange {provider.value} - callback_uri: {callback_uri}")

        token_data = {
            "grant_type": "authorization_code",
            "client_id": config.client_id,
            "client_secret": config.client_secret,
            "code": code,
            "redirect_uri": callback_uri
        }
        logger.info(f"Enviando petición de token a: {config.token_url}")
        token_info = self._fetch_json(
            "POST",
            config.token_url,
            log_prefix="Error obteniendo token",
            error_message="Error obteniendo token de acceso",
            data=token_data,
            headers={"Accept": "application/json"},
        )
        access_token = token_info.get("access_token") if isinstance(token_info, dict) else None
        if not access_token:
            raise AuthError("No se recibió token de acceso")

        user_info = self._fetch_json(
            "GET",
            config.userinfo_url,
            log_prefix="Error obteniendo info del usuario",
            error_message="Error obteniendo información del usuario",
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/json"
            },
        )
        if not isinstance(user_info, dict):
            raise AuthError("Error obteniendo información del usuario")
        return user_info

    def _create_or_update_user(self, provider: OAuthProvider, user_info: Dict[str, Any]) -> User:
        """Map a provider userinfo payload to a User, creating or refreshing it.

        Existing users are matched by ``(provider, provider_id)``. Access
        control is checked before any user is created or modified.

        Raises:
            AuthError: For unsupported providers, payloads without an email or
                id, or when ``access_control`` denies access.
        """
        if provider == OAuthProvider.GOOGLE:
            email = user_info.get("email")
            name = user_info.get("name")
            # The v2 userinfo endpoint returns "id"; the OpenID Connect one returns "sub".
            provider_id = user_info.get("id") or user_info.get("sub")
            avatar_url = user_info.get("picture")
        elif provider == OAuthProvider.GITHUB:
            login = user_info.get("login")
            # GitHub returns a null email when the user keeps it private; fall
            # back to a synthetic address built from the login (if there is one).
            email = user_info.get("email") or (f"{login}@github.com" if login else None)
            name = user_info.get("name") or login
            github_id = user_info.get("id")
            # GitHub ids are numeric. Stringify them, but never turn a missing id
            # into the truthy string "None".
            provider_id = str(github_id) if github_id is not None else None
            avatar_url = user_info.get("avatar_url")
        elif provider == OAuthProvider.MICROSOFT:
            email = user_info.get("mail") or user_info.get("userPrincipalName")
            name = user_info.get("displayName")
            provider_id = user_info.get("id")
            avatar_url = None  # Not read from the Microsoft payload.
        else:
            raise AuthError(f"Proveedor {provider.value} no soportado")

        if not email or not provider_id:
            raise AuthError("Información de usuario incompleta")

        if not name:
            name = "Usuario"

        existing_user = next(
            (
                candidate for candidate in self.users.values()
                if candidate.provider == provider and candidate.provider_id == provider_id
            ),
            None,
        )

        access_allowed, access_message = access_control.check_user_access(email, existing_user is not None)
        if not access_allowed:
            raise AuthError(access_message)

        if existing_user:
            existing_user.name = name
            existing_user.avatar_url = avatar_url
            existing_user.last_login = datetime.now(timezone.utc)
            logger.info(f"Usuario actualizado: {existing_user.email}")
            return existing_user

        user = User.create_new(
            email=email,
            name=name,
            provider=provider,
            provider_id=provider_id,
            avatar_url=avatar_url
        )
        user.last_login = datetime.now(timezone.utc)
        self.users[user.id] = user
        logger.info(f"Nuevo usuario creado: {user.email}")
        return user

    def _create_access_token(self, user_id: str) -> str:
        """Sign a JWT access token for ``user_id`` and record it in memory.

        Also purges expired access and refresh tokens.
        """
        now = datetime.now(timezone.utc)
        payload = {
            "sub": user_id,
            "exp": now + timedelta(minutes=self.access_token_expire_minutes),
            "iat": now,
            "type": "access_token",
            "scopes": ["mcp:read", "mcp:write"]
        }
        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

        self.access_tokens[token] = AccessToken.create_new(
            user_id=user_id,
            token=token,
            expires_in_seconds=self.access_token_expire_minutes * 60
        )
        self._cleanup_expired_tokens()
        return token

    def _create_refresh_token(self, user_id: str) -> str:
        """Create and record an opaque random refresh token for ``user_id``."""
        token = secrets.token_urlsafe(32)
        self.refresh_tokens[token] = RefreshToken.create_new(
            user_id=user_id,
            token=token,
            expires_in_days=self.refresh_token_expire_days
        )
        return token

    def verify_access_token(self, token: str) -> User:
        """Validate an access token and return its (active) user.

        Raises:
            TokenExpiredError: If the JWT ``exp`` claim or the stored record
                has expired.
            InvalidTokenError: If the token is malformed, badly signed, not an
                access token, unknown to this service, or its user is gone.
            AuthError: If the user is deactivated.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except ExpiredSignatureError:
            # Must precede JWTError (its base class). A validly signed but
            # expired token is reported as expiry, and its record is dropped.
            self.access_tokens.pop(token, None)
            raise TokenExpiredError() from None
        except JWTError:
            raise InvalidTokenError() from None

        user_id = payload.get("sub")
        if not user_id or payload.get("type") != "access_token":
            raise InvalidTokenError()

        # The JWT must also be known to this service (not purged or forged).
        access_token = self.access_tokens.get(token)
        if access_token is None:
            raise InvalidTokenError()
        if access_token.is_expired():
            del self.access_tokens[token]
            raise TokenExpiredError()

        user = self.users.get(user_id)
        if user is None:
            raise InvalidTokenError()
        if not user.is_active:
            raise AuthError("Usuario desactivado")
        return user

    def refresh_access_token(self, refresh_token: str) -> str:
        """Issue a new access token from a valid, non-revoked refresh token.

        An expired or revoked refresh token is removed and rejected.

        Raises:
            InvalidTokenError: If the refresh token is unknown, expired or revoked.
        """
        refresh_token_obj = self.refresh_tokens.get(refresh_token)
        if refresh_token_obj is None:
            raise InvalidTokenError()
        if refresh_token_obj.is_expired() or refresh_token_obj.is_revoked:
            self.refresh_tokens.pop(refresh_token, None)
            raise InvalidTokenError()

        new_access_token = self._create_access_token(refresh_token_obj.user_id)
        logger.info(f"Token de acceso refrescado para usuario: {refresh_token_obj.user_id}")
        return new_access_token

    def revoke_refresh_token(self, refresh_token: str):
        """Mark a refresh token as revoked.

        Unknown tokens are ignored. The record stays in memory until it
        expires.
        """
        if refresh_token in self.refresh_tokens:
            self.refresh_tokens[refresh_token].is_revoked = True
            logger.info("Token de refresco revocado")

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with ``user_id``, or None if unknown."""
        return self.users.get(user_id)

    @staticmethod
    def _purge_expired(store: Dict[str, Any]) -> int:
        """Delete entries whose ``is_expired()`` is true; return how many were deleted."""
        expired_keys = [key for key, item in store.items() if item.is_expired()]
        for key in expired_keys:
            del store[key]
        return len(expired_keys)

    def _cleanup_expired_tokens(self):
        """Drop expired access and refresh tokens from the in-memory stores."""
        expired_access = self._purge_expired(self.access_tokens)
        expired_refresh = self._purge_expired(self.refresh_tokens)
        if expired_access or expired_refresh:
            logger.debug(f"Limpieza: {expired_access} access tokens, {expired_refresh} refresh tokens")

    def _cleanup_expired_states(self):
        """Drop expired pending OAuth2 states."""
        expired_states = self._purge_expired(self.oauth_states)
        if expired_states:
            logger.debug(f"Limpieza: {expired_states} estados OAuth2 expirados")

    def get_stats(self) -> Dict[str, Any]:
        """Return counters describing the in-memory state of the service."""
        return {
            "users_count": len(self.users),
            "active_access_tokens": len(self.access_tokens),
            "active_refresh_tokens": len([t for t in self.refresh_tokens.values() if not t.is_revoked]),
            "pending_oauth_states": len(self.oauth_states),
            "configured_providers": list(self.providers.keys())
        }
# Global authentication service instance. It is created at import time, so
# importing this module reads the JWT/OAuth environment variables and logs
# the configured providers.
auth_service = AuthService()