// ABOUTME: Connection management handlers for OAuth providers
// ABOUTME: Handle connection status, disconnection, and connection initiation
//
// SPDX-License-Identifier: MIT OR Apache-2.0
// Copyright (c) 2025 Pierre Fitness Intelligence
use crate::constants::oauth_config::AUTHORIZATION_EXPIRES_MINUTES;
use crate::database_plugins::DatabaseProvider;
use crate::protocols::universal::{UniversalRequest, UniversalResponse, UniversalToolExecutor};
use crate::protocols::ProtocolError;
use crate::tenant::{TenantContext, TenantRole};
use crate::utils::uuid::parse_user_id_for_protocol;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use tracing::{error, info};
/// Handle `get_connection_status` tool - check OAuth connection status
///
/// With a `provider` parameter only that provider is checked; without it, every provider
/// returned by the provider registry's `supported_providers()` is checked. A provider is
/// reported as connected only when `get_valid_token` yields `Ok(Some(_))`; a token lookup
/// error is reported the same way as a missing token (disconnected).
///
/// # Errors
///
/// Resolves to `ProtocolError::OperationCancelled` when the request's cancellation token is
/// already cancelled, or to the error from `parse_user_id_for_protocol` when it rejects
/// `request.user_id`.
#[must_use]
pub fn handle_get_connection_status(
    executor: &UniversalToolExecutor,
    request: UniversalRequest,
) -> Pin<Box<dyn Future<Output = Result<UniversalResponse, ProtocolError>> + Send + '_>> {
    /// Map a connection flag to the status label exposed to clients.
    fn status_label(connected: bool) -> &'static str {
        if connected {
            "connected"
        } else {
            "disconnected"
        }
    }

    Box::pin(async move {
        // Refuse to do any work for a request that was cancelled before we started
        let already_cancelled = match &request.cancellation_token {
            Some(token) => token.is_cancelled().await,
            None => false,
        };
        if already_cancelled {
            return Err(ProtocolError::OperationCancelled(
                "handle_get_connection_status cancelled by user".to_owned(),
            ));
        }

        let user_uuid = parse_user_id_for_protocol(&request.user_id)?;

        match request.parameters.get("provider").and_then(Value::as_str) {
            Some(provider) => {
                // Single provider mode
                let connected = executor
                    .auth_service
                    .get_valid_token(user_uuid, provider, request.tenant_id.as_deref())
                    .await
                    .ok()
                    .flatten()
                    .is_some();
                let metadata = HashMap::from([
                    ("user_id".to_owned(), Value::String(user_uuid.to_string())),
                    ("provider".to_owned(), Value::String(provider.to_owned())),
                    (
                        "tenant_id".to_owned(),
                        request.tenant_id.map_or(Value::Null, Value::String),
                    ),
                ]);
                Ok(UniversalResponse {
                    success: true,
                    result: Some(json!({
                        "provider": provider,
                        "status": status_label(connected),
                        "connected": connected
                    })),
                    error: None,
                    metadata: Some(metadata),
                })
            }
            None => {
                // Multi-provider mode: report every provider the registry knows about
                let registry_providers = executor.resources.provider_registry.supported_providers();
                let mut providers_status = Map::new();
                for name in registry_providers {
                    let connected = executor
                        .auth_service
                        .get_valid_token(user_uuid, name, request.tenant_id.as_deref())
                        .await
                        .ok()
                        .flatten()
                        .is_some();
                    providers_status.insert(
                        name.to_owned(),
                        json!({
                            "connected": connected,
                            "status": status_label(connected)
                        }),
                    );
                }
                let metadata = HashMap::from([
                    ("user_id".to_owned(), Value::String(user_uuid.to_string())),
                    (
                        "tenant_id".to_owned(),
                        request.tenant_id.map_or(Value::Null, Value::String),
                    ),
                ]);
                Ok(UniversalResponse {
                    success: true,
                    result: Some(json!({
                        "providers": providers_status
                    })),
                    error: None,
                    metadata: Some(metadata),
                })
            }
        }
    })
}
/// Handle `disconnect_provider` tool - disconnect user from OAuth provider
///
/// Requires a `provider` parameter and removes the stored OAuth token for that provider via
/// `delete_user_oauth_token`. When the request carries no `tenant_id`, the literal tenant
/// `"default"` is passed to the database call. A missing `provider` or a database failure
/// is reported as an unsuccessful `UniversalResponse`, not as a `ProtocolError`.
///
/// # Errors
///
/// Resolves to `ProtocolError::OperationCancelled` when the request's cancellation token is
/// already cancelled, or to the error from `parse_user_id_for_protocol` when it rejects
/// `request.user_id`.
#[must_use]
pub fn handle_disconnect_provider(
    executor: &UniversalToolExecutor,
    request: UniversalRequest,
) -> Pin<Box<dyn Future<Output = Result<UniversalResponse, ProtocolError>> + Send + '_>> {
    Box::pin(async move {
        // Refuse to do any work for a request that was cancelled before we started
        let already_cancelled = match request.cancellation_token.as_ref() {
            Some(token) => token.is_cancelled().await,
            None => false,
        };
        if already_cancelled {
            return Err(ProtocolError::OperationCancelled(
                "handle_disconnect_provider cancelled by user".to_owned(),
            ));
        }

        let user_uuid = parse_user_id_for_protocol(&request.user_id)?;

        // The provider name is mandatory; list the registry's providers to help the caller
        let Some(provider) = request.parameters.get("provider").and_then(Value::as_str) else {
            let registry = &executor.resources.provider_registry;
            let supported = registry.supported_providers().join(", ");
            return Ok(connection_error(format!(
                "Missing required 'provider' parameter. Supported providers: {supported}"
            )));
        };

        // The delete call is scoped by tenant; "default" is used when none was supplied
        let tenant_key = request.tenant_id.as_deref().unwrap_or("default");
        let outcome = (*executor.resources.database)
            .delete_user_oauth_token(user_uuid, tenant_key, provider)
            .await;

        // Success and failure responses carry the same identifying metadata
        let metadata = HashMap::from([
            ("user_id".to_owned(), Value::String(user_uuid.to_string())),
            ("provider".to_owned(), Value::String(provider.to_owned())),
            (
                "tenant_id".to_owned(),
                request.tenant_id.map_or(Value::Null, Value::String),
            ),
        ]);

        let response = match outcome {
            Ok(()) => UniversalResponse {
                success: true,
                result: Some(json!({
                    "provider": provider,
                    "status": "disconnected",
                    "message": format!("Successfully disconnected from {provider}")
                })),
                error: None,
                metadata: Some(metadata),
            },
            Err(e) => UniversalResponse {
                success: false,
                result: None,
                error: Some(format!("Failed to disconnect from {provider}: {e}")),
                metadata: Some(metadata),
            },
        };
        Ok(response)
    })
}
/// Build the response returned once an OAuth authorization URL has been generated.
///
/// The result carries the provider, the authorization URL, the OAuth `state` value,
/// step-by-step user instructions, the expiry (`AUTHORIZATION_EXPIRES_MINUTES`), and a
/// `pending_authorization` status. The metadata records the user, tenant, and provider.
fn build_oauth_success_response(
    user_uuid: uuid::Uuid,
    tenant_id: uuid::Uuid,
    provider: &str,
    authorization_url: &str,
    state: &str,
) -> UniversalResponse {
    // Line continuations drop the leading whitespace, so each step starts at column 0
    let instructions = format!(
        "To connect your {provider} account:\n\
         1. Visit the authorization URL\n\
         2. Log in to {provider} and approve the connection\n\
         3. You will be redirected back to complete the connection\n\
         4. Once connected, you can access your {provider} data through Pierre"
    );
    let metadata = HashMap::from([
        ("user_id".to_owned(), Value::String(user_uuid.to_string())),
        ("tenant_id".to_owned(), Value::String(tenant_id.to_string())),
        ("provider".to_owned(), Value::String(provider.to_owned())),
    ]);
    UniversalResponse {
        success: true,
        result: Some(json!({
            "provider": provider,
            "authorization_url": authorization_url,
            "state": state,
            "instructions": instructions,
            "expires_in_minutes": AUTHORIZATION_EXPIRES_MINUTES,
            "status": "pending_authorization"
        })),
        error: None,
        metadata: Some(metadata),
    }
}
/// Build the failure response used when an OAuth authorization URL cannot be generated.
///
/// The message embeds the underlying `error` text and points the caller at the OAuth
/// credential configuration for `provider`. The metadata tags the failure with
/// `error_type = "oauth_configuration_error"` and records the provider.
fn build_oauth_error_response(provider: &str, error: &str) -> UniversalResponse {
    let message = format!(
        "Failed to generate authorization URL: {error}. \
         Please check that OAuth credentials are configured for provider '{provider}'."
    );
    let metadata = HashMap::from([
        (
            "error_type".to_owned(),
            Value::String("oauth_configuration_error".to_owned()),
        ),
        ("provider".to_owned(), Value::String(provider.to_owned())),
    ]);
    UniversalResponse {
        success: false,
        result: None,
        error: Some(message),
        metadata: Some(metadata),
    }
}
/// Create a failed `UniversalResponse` for connection operations.
///
/// Only the `error` field is populated with `message`; `result` and `metadata` are empty.
#[inline]
fn connection_error(message: impl Into<String>) -> UniversalResponse {
    let error = Some(message.into());
    UniversalResponse {
        error,
        success: false,
        result: None,
        metadata: None,
    }
}
/// Check whether a client-supplied `redirect_url` may be embedded in the OAuth state.
///
/// Accepted forms:
/// - `pierre://…` and `exp://…` app deep links,
/// - any `https://…` URL,
/// - `http://` only when the host is exactly `localhost`, optionally followed by a numeric
///   `:port` and then a path, query, or fragment.
///
/// A bare `starts_with("http://localhost")` test is not enough. It also accepts
/// `http://localhost.attacker.example/` (a different host) and
/// `http://localhost@attacker.example/` (userinfo, so the real host is `attacker.example`).
/// Either would let the post-OAuth redirect land on an arbitrary plain-HTTP site.
fn is_allowed_redirect_url(url: &str) -> bool {
    if url.starts_with("pierre://") || url.starts_with("exp://") || url.starts_with("https://") {
        return true;
    }
    let Some(rest) = url.strip_prefix("http://") else {
        return false;
    };
    // The authority component ends at the first path, query, or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Any userinfo (`user@host`) means the effective host is not `localhost`.
    if authority.contains('@') {
        return false;
    }
    let Some(after_host) = authority.strip_prefix("localhost") else {
        return false;
    };
    // After the host name only an optional `:port` made of ASCII digits is allowed.
    match after_host.strip_prefix(':') {
        Some(port) => port.bytes().all(|b| b.is_ascii_digit()),
        None => after_host.is_empty(),
    }
}

/// Handle `connect_provider` tool - initiate OAuth connection flow
///
/// Accepts optional `redirect_url` parameter for mobile app OAuth flows.
/// When provided, the redirect URL is base64 encoded and included in the OAuth state,
/// allowing the server to redirect back to the mobile app after OAuth completes.
/// The URL must satisfy `is_allowed_redirect_url`; otherwise an error response is returned.
///
/// The tenant is taken from the stored user record, then from `request.tenant_id`, and
/// falls back to the user's own UUID when neither parses as a UUID.
///
/// # Errors
///
/// Resolves to `ProtocolError::OperationCancelled` when the request is already cancelled, or
/// to the error from `parse_user_id_for_protocol` when it rejects `request.user_id`. Every
/// other failure is reported as an unsuccessful `UniversalResponse`: missing or unsupported
/// provider, disallowed redirect URL, unknown user, database error, or OAuth URL generation
/// failure.
#[must_use]
pub fn handle_connect_provider(
    executor: &UniversalToolExecutor,
    request: UniversalRequest,
) -> Pin<Box<dyn Future<Output = Result<UniversalResponse, ProtocolError>> + Send + '_>> {
    Box::pin(async move {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
        if let Some(token) = &request.cancellation_token {
            if token.is_cancelled().await {
                return Err(ProtocolError::OperationCancelled(
                    "handle_connect_provider cancelled by user".to_owned(),
                ));
            }
        }
        let user_uuid = parse_user_id_for_protocol(&request.user_id)?;
        let registry = &executor.resources.provider_registry;
        let db = &executor.resources.database;
        // Extract and validate provider parameter
        let Some(provider) = request.parameters.get("provider").and_then(Value::as_str) else {
            let supported = registry.supported_providers().join(", ");
            return Ok(connection_error(format!(
                "Missing required 'provider' parameter. Supported providers: {supported}"
            )));
        };
        if !registry.is_supported(provider) {
            let supported = registry.supported_providers().join(", ");
            return Ok(connection_error(format!(
                "Provider '{provider}' is not supported. Supported providers: {supported}"
            )));
        }
        // Extract optional redirect_url for mobile OAuth flows
        let redirect_url = request
            .parameters
            .get("redirect_url")
            .and_then(Value::as_str);
        // Security check: the redirect target is embedded in the OAuth state, so only app deep
        // links, HTTPS, or local-development HTTP on exactly `localhost` are accepted.
        if let Some(url) = redirect_url {
            if !is_allowed_redirect_url(url) {
                return Ok(connection_error(
                    "Invalid redirect_url scheme. Allowed schemes: pierre://, exp://, http://localhost, https://",
                ));
            }
        }
        // Get user and extract tenant context
        // Try user's tenant_id first, then fall back to request.tenant_id (for chat interface)
        let user = match db.get_user(user_uuid).await {
            Ok(Some(u)) => u,
            Ok(None) => return Ok(connection_error(format!("User {user_uuid} not found"))),
            Err(e) => return Ok(connection_error(format!("Database error: {e}"))),
        };
        let tenant_id = user
            .tenant_id
            .as_ref()
            .and_then(|t| uuid::Uuid::parse_str(t).ok())
            .or_else(|| {
                request
                    .tenant_id
                    .as_ref()
                    .and_then(|t| uuid::Uuid::parse_str(t).ok())
            })
            .unwrap_or(user_uuid);
        // A failed tenant lookup is not fatal; the flow proceeds with a placeholder name
        let tenant_name = db
            .get_tenant_by_id(tenant_id)
            .await
            .map_or_else(|_| "Unknown Tenant".to_owned(), |t| t.name);
        let ctx = TenantContext {
            tenant_id,
            user_id: user_uuid,
            tenant_name,
            user_role: TenantRole::Member,
        };
        // Build OAuth state with optional redirect URL
        // Format: {user_id}:{random_uuid}:{base64_redirect_url} (third part optional)
        let state = redirect_url.map_or_else(
            || format!("{}:{}", user_uuid, uuid::Uuid::new_v4()),
            |url| {
                let encoded_url = URL_SAFE_NO_PAD.encode(url.as_bytes());
                format!("{}:{}:{}", user_uuid, uuid::Uuid::new_v4(), encoded_url)
            },
        );
        // Generate OAuth authorization URL
        match executor
            .resources
            .tenant_oauth_client
            .get_authorization_url(&ctx, provider, &state, db.as_ref())
            .await
        {
            Ok(url) => {
                info!(
                    "Generated OAuth URL for user {} provider {}{}",
                    user_uuid,
                    provider,
                    if redirect_url.is_some() {
                        " (mobile flow)"
                    } else {
                        ""
                    }
                );
                Ok(build_oauth_success_response(
                    user_uuid, tenant_id, provider, &url, &state,
                ))
            }
            Err(e) => {
                error!("OAuth URL generation failed for {}: {}", provider, e);
                Ok(build_oauth_error_response(provider, &e.to_string()))
            }
        }
    })
}