# oauth.py • 5.1 kB
from pydantic import BaseModel, Field, HttpUrl
from pydantic.types import UUID4
from pydantic import field_validator as validator
from typing import Optional, List, Union
from datetime import datetime
import uuid
class OAuthClientBase(BaseModel):
    """Fields shared by the OAuth client schemas.

    ``scopes`` is checked against a fixed list of permitted values; the
    remaining fields are only type-checked.
    """

    client_name: str
    redirect_uris: List[str]
    scopes: List[str]
    # Defaults to True when the caller does not say otherwise.
    is_confidential: bool = True

    @validator("scopes")
    @classmethod
    def validate_scopes(cls, v):
        """Return ``v`` unchanged, or raise ``ValueError`` if any entry is unknown."""
        permitted = ("memories:read", "memories:write", "memories:admin")
        if any(item not in permitted for item in v):
            raise ValueError(f"Scope must be one of: {', '.join(permitted)}")
        return v
class OAuthClientCreate(OAuthClientBase):
    """Payload for creating a new OAuth client.

    Declares no fields of its own; everything is inherited from
    ``OAuthClientBase``.
    """
class OAuthClientUpdate(BaseModel):
    """Payload for updating an OAuth client.

    Every field is optional and defaults to ``None``.
    """

    client_name: Optional[str] = None
    redirect_uris: Optional[List[str]] = None
    scopes: Optional[List[str]] = None
    is_confidential: Optional[bool] = None

    @validator("scopes")
    @classmethod
    def validate_scopes(cls, v):
        """Pass ``None`` through; otherwise require every entry to be a known scope."""
        permitted = ("memories:read", "memories:write", "memories:admin")
        if v is not None and any(item not in permitted for item in v):
            raise ValueError(f"Scope must be one of: {', '.join(permitted)}")
        return v
class OAuthClientInDBBase(OAuthClientBase):
    """Base schema for OAuth client data read from the database.

    Adds the record identifiers and timestamps to the fields inherited
    from ``OAuthClientBase``.
    """

    # Pydantic v2 configuration. The nested ``class Config`` form is
    # deprecated in v2 and emits a deprecation warning, so the setting is
    # expressed through ``model_config`` instead. ``from_attributes`` allows
    # the model to be built from an object's attributes rather than a dict.
    model_config = {"from_attributes": True}

    id: UUID4
    client_id: UUID4
    created_at: datetime
    updated_at: datetime
class OAuthClient(OAuthClientInDBBase):
    """OAuth client data as returned to the caller.

    Declares no fields beyond those of ``OAuthClientInDBBase``.
    """
class OAuthClientInDB(OAuthClientInDBBase):
    """OAuth client data as stored in the database.

    Unlike ``OAuthClient``, this schema carries the client secret.
    """

    client_secret: str
class OAuthClientRegisterResponse(BaseModel):
    """Response body for OAuth client registration.

    All fields are required; this schema includes ``client_secret``.
    """

    # Credentials for the registered client.
    client_id: UUID4
    client_secret: str

    # Client settings.
    client_name: str
    redirect_uris: List[str]
    scopes: List[str]
    is_confidential: bool
class AuthorizationRequest(BaseModel):
    """Parameters of an OAuth authorization request.

    Only ``response_type="code"`` is accepted, and when a
    ``code_challenge_method`` is supplied it must be ``"S256"``.
    """

    response_type: str
    client_id: UUID4
    redirect_uri: str
    scope: str
    state: str
    # Both challenge fields are optional and default to None.
    code_challenge: Optional[str] = None
    code_challenge_method: Optional[str] = None

    @validator("response_type")
    @classmethod
    def validate_response_type(cls, v):
        """Return ``v`` when it is ``"code"``; raise ``ValueError`` otherwise."""
        if v == "code":
            return v
        raise ValueError("response_type must be 'code'")

    @validator("code_challenge_method")
    @classmethod
    def validate_code_challenge_method(cls, v):
        """Return ``v`` when it is ``None`` or ``"S256"``; raise ``ValueError`` otherwise."""
        if v is None or v == "S256":
            return v
        raise ValueError("code_challenge_method must be 'S256'")
class TokenRequest(BaseModel):
    """Parameters of an OAuth token request.

    ``grant_type`` and ``client_id`` are required. Every other field is
    optional and defaults to ``None``; this schema does not check which of
    them accompany a given grant type.
    """

    grant_type: str
    code: Optional[str] = None
    redirect_uri: Optional[str] = None
    client_id: UUID4
    client_secret: Optional[str] = None
    code_verifier: Optional[str] = None
    refresh_token: Optional[str] = None

    @validator("grant_type")
    @classmethod
    def validate_grant_type(cls, v):
        """Return ``v`` when it is a supported grant type; raise ``ValueError`` otherwise."""
        supported = ("authorization_code", "refresh_token")
        if v in supported:
            return v
        raise ValueError(f"grant_type must be one of: {', '.join(supported)}")
class TokenResponse(BaseModel):
    """Successful OAuth token response.

    NOTE(review): ``TokenResponse`` is defined a second time further down
    this module; the later definition rebinds the name, so this one is
    shadowed for anything that looks the name up after import.
    """

    access_token: str
    # Falls back to "bearer" when the caller does not supply a value.
    token_type: str = "bearer"
    expires_in: int
    refresh_token: str
    scope: str
class TokenError(BaseModel):
    """Error body for a failed OAuth token request."""

    # Required error value.
    error: str
    # Optional description; None when omitted.
    error_description: Optional[str] = None
class OAuthScope(BaseModel):
    """A single OAuth scope: its name plus a description."""

    name: str
    description: str
class IntrospectionRequest(BaseModel):
    """Request body for token introspection."""

    # The token to introspect (required).
    token: str
    # Optional hint; None when omitted. Its value is not validated here.
    token_type_hint: Optional[str] = None
class IntrospectionResponse(BaseModel):
    """Schema for token introspection response.

    Only ``active`` is required; every other field is optional and
    defaults to ``None``.

    NOTE(review): ``code_verifier``, ``refresh_token`` and ``grant_type``
    mirror the fields of ``TokenRequest`` and look unusual for an
    introspection response — confirm they are intended here.
    """

    active: bool
    client_id: Optional[UUID4] = None
    user_id: Optional[UUID4] = None
    scope: Optional[str] = None
    exp: Optional[int] = None
    iat: Optional[int] = None
    code_verifier: Optional[str] = None
    refresh_token: Optional[str] = None
    grant_type: Optional[str] = None

    @validator("grant_type")
    @classmethod
    def validate_grant_type(cls, v):
        """Validate grant_type value.

        Returns ``v`` when it is ``None`` or one of the supported grant
        types.

        Raises:
            ValueError: if ``v`` is a value other than a supported grant type.
        """
        # The field is Optional, so an explicitly supplied None (for example
        # when re-validating a dumped model) must be accepted, not rejected.
        if v is None:
            return v
        allowed_grant_types = ["authorization_code", "refresh_token"]
        if v not in allowed_grant_types:
            raise ValueError(f"grant_type must be one of: {', '.join(allowed_grant_types)}")
        return v
class TokenResponse(BaseModel):
    """Schema for OAuth token response.

    NOTE(review): this is the second definition of ``TokenResponse`` in
    this module. Being the later one, it is the class the module name
    ends up bound to. Consider removing one of the two definitions.
    """

    access_token: str
    # The earlier definition of this class declares "bearer" as the default;
    # keep that default here so the rebinding does not silently turn
    # token_type into a required field.
    token_type: str = "bearer"
    expires_in: int
    refresh_token: str
    scope: str