"""
Dependencies for FastAPI dependency injection.
"""
import json
import logging
import os
import re
import socket
from datetime import datetime
from pathlib import Path
from typing import Annotated, Generator
from urllib.parse import urlparse
from fastapi import Depends, Request
from pydantic import AnyUrl
from pyspark.sql import SparkSession
# Use MCP server's local copy of spark session utilities
# (copied from berdl_notebook_utils but adapted for shared multi-user service)
from src.delta_lake.setup_spark_session import get_spark_session as _get_spark_session
from src.service import app_state
from src.service.http_bearer import KBaseHTTPBearer
from src.settings import BERDLSettings, get_settings
# Initialize the KBase auth dependency for use in routes
auth = KBaseHTTPBearer()
# Configure logging
logger = logging.getLogger(__name__)
def sanitize_k8s_name(name: str) -> str:
"""
Sanitize a string to be Kubernetes DNS-1123 subdomain compliant.
Kubernetes resource names must:
- Consist of lowercase alphanumeric characters, '-', or '.'
- Start and end with an alphanumeric character
- Be at most 253 characters long
Args:
name: The string to sanitize (e.g., username with underscores)
Returns:
A DNS-1123 compliant string (underscores and other invalid characters become hyphens)
"""
# Replace underscores and other invalid characters with hyphens
sanitized = re.sub(r"[^a-z0-9.-]", "-", name.lower())
# Ensure it starts and ends with alphanumeric
sanitized = re.sub(r"^[^a-z0-9]+", "", sanitized)
sanitized = re.sub(r"[^a-z0-9]+$", "", sanitized)
# Collapse multiple consecutive hyphens
sanitized = re.sub(r"-+", "-", sanitized)
# Truncate to 253 characters (K8s limit)
return sanitized[:253]
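# Illustrative behaviour of sanitize_k8s_name (hypothetical usernames, shown only
# as a sketch of the rules above):
#   sanitize_k8s_name("Tian_Gu_Test")  -> "tian-gu-test"
#   sanitize_k8s_name("_a__b_")        -> "a-b"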
DEFAULT_SPARK_POOL = "default"
SPARK_CONNECT_PORT = "15002"
def read_user_minio_credentials(username: str) -> tuple[str, str]:
"""
Read user's MinIO credentials from their home directory.
Each user has a .berdl_minio_credentials file in their home directory with format:
{"username": "user", "access_key": "key", "secret_key": "secret"}
Args:
username: KBase username
Returns:
Tuple of (access_key, secret_key)
Raises:
FileNotFoundError: If credentials file doesn't exist
ValueError: If credentials file is malformed
"""
# Construct path to credentials file
creds_path = Path(f"/home/{username}/.berdl_minio_credentials")
logger.debug(f"Reading MinIO credentials from: {creds_path}")
if not creds_path.exists():
raise FileNotFoundError(
f"MinIO credentials file not found at {creds_path}. "
f"User {username} must have .berdl_minio_credentials in their home directory."
)
try:
with open(creds_path, "r") as f:
creds = json.load(f)
access_key = creds.get("access_key")
secret_key = creds.get("secret_key")
if not access_key or not secret_key:
raise ValueError(
f"Invalid credentials format in {creds_path}. "
f'Expected: {{"username": "user", "access_key": "key", "secret_key": "secret"}}'
)
logger.info(f"Successfully loaded MinIO credentials for user: {username}")
return access_key, secret_key
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse MinIO credentials file {creds_path}: {e}")
except Exception as e:
logger.error(f"Error reading MinIO credentials for {username}: {e}")
raise
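# For illustration only: a hypothetical /home/alice/.berdl_minio_credentials file
# matching the format documented above (placeholder values, not real keys):
#   {"username": "alice", "access_key": "example-access-key", "secret_key": "example-secret-key"}
# read_user_minio_credentials("alice") would then return
# ("example-access-key", "example-secret-key").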
def get_user_from_request(request: Request) -> str:
"""
Extract the authenticated user from the request state.
The user is set by the AuthMiddleware after validating the Bearer token.
Args:
request: FastAPI request object
Returns:
Username of the authenticated user
Raises:
Exception: If user is not authenticated
"""
user = app_state.get_request_user(request)
if user is None:
raise Exception("User not authenticated. Authorization header required.")
return user.user
def construct_user_spark_connect_url(username: str) -> str:
"""
Construct the Spark Connect URL for a specific user's notebook pod.
In BERDL, each user has their own notebook pod with a Spark Connect server.
The URL pattern differs between environments:
- Docker Compose (local dev): sc://spark-notebook:15002 or sc://spark-notebook-{username}:15002
- Kubernetes (prod/stage/dev): sc://jupyter-{sanitized-username}.jupyterhub-{env}:15002
Args:
username: KBase username (may contain underscores or special characters)
Returns:
User-specific Spark Connect URL with DNS-safe username
Notes:
For docker-compose local development, service names do not follow the Kubernetes
jupyter-{username} pattern; set the SPARK_CONNECT_URL_TEMPLATE environment variable
to override the default pattern.
For Kubernetes, the MCP server is in namespace (dev/prod/stage) and notebooks are in
namespace (jupyterhub-dev/jupyterhub-prod/jupyterhub-stage), so we need cross-namespace DNS.
IMPORTANT: The username is sanitized to be DNS-1123 compliant (underscores → hyphens)
to match the Kubernetes Service name created by JupyterHub.
"""
# Check if there's a custom template (useful for docker-compose)
template = os.getenv("SPARK_CONNECT_URL_TEMPLATE")
if template:
# Template should contain {username} placeholder
# Example: "sc://spark-notebook:15002" (no placeholder = shared)
# Example: "sc://spark-notebook-{username}:15002"
# Note: For templates, we use the sanitized username to match Kubernetes Service names
sanitized_username = sanitize_k8s_name(username)
url = template.format(username=sanitized_username)
logger.info(
f"Using custom Spark Connect URL template: {url} (username: {username} → {sanitized_username})"
)
return url
# For Kubernetes: need to determine the environment and construct cross-namespace DNS
# Environment can be dev, prod, or stage
k8s_env = os.getenv("K8S_ENVIRONMENT", "dev") # Default to dev if not specified
# Sanitize username for DNS-1123 compliance (e.g., tian_gu_test → tian-gu-test)
# This must match the Service name created by JupyterHub's spark_connect_service.py
sanitized_username = sanitize_k8s_name(username)
# Cross-namespace DNS pattern: {service}.{namespace}.svc.cluster.local
# But short form works too: {service}.{namespace}
notebook_namespace = f"jupyterhub-{k8s_env}"
url = f"sc://jupyter-{sanitized_username}.{notebook_namespace}:{SPARK_CONNECT_PORT}"
logger.info(
f"Using Kubernetes cross-namespace Spark Connect URL: {url} (username: {username} → {sanitized_username})"
)
return url
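# Sketch of the two code paths above (environment values are examples only):
#   With SPARK_CONNECT_URL_TEMPLATE="sc://spark-notebook-{username}:15002":
#       construct_user_spark_connect_url("tian_gu_test")
#       -> "sc://spark-notebook-tian-gu-test:15002"
#   With no template and K8S_ENVIRONMENT="prod":
#       construct_user_spark_connect_url("tian_gu_test")
#       -> "sc://jupyter-tian-gu-test.jupyterhub-prod:15002"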
def is_spark_connect_reachable(spark_connect_url: str, timeout: float = 1.0) -> bool:
"""
Quick TCP check if Spark Connect server is reachable.
Args:
spark_connect_url: Spark Connect URL (e.g., "sc://jupyter-user.namespace:15002")
timeout: Connection timeout in seconds (default: 1.0)
Returns:
True if port is reachable, False otherwise
"""
try:
# Parse URL to extract host and port
# Format: sc://host:port
url_str = spark_connect_url.replace("sc://", "tcp://")
parsed = urlparse(url_str)
host = parsed.hostname
port = parsed.port or int(SPARK_CONNECT_PORT)  # Default Spark Connect port
if not host:
logger.debug(f"Failed to parse hostname from URL: {spark_connect_url}")
return False
# Attempt TCP connection
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception as e:
logger.debug(f"TCP check failed for {spark_connect_url}: {e}")
return False
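# Usage sketch (hypothetical URL): a cheap pre-flight check before paying for a
# full Spark Connect handshake.
#   if is_spark_connect_reachable("sc://jupyter-alice.jupyterhub-dev:15002", timeout=0.5):
#       ...  # attempt the Spark Connect session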
def get_spark_session(
request: Request,
settings: Annotated[BERDLSettings, Depends(get_settings)],
) -> Generator[SparkSession, None, None]:
"""
Get a SparkSession instance configured for the authenticated user with automatic cleanup.
This function tries to connect to the user's personal Spark Connect server first.
If unavailable, it falls back to a shared Spark cluster. The session is created
fresh for each request with the user's MinIO credentials, ensuring proper isolation.
The session is automatically stopped after the request completes via generator cleanup.
Connection Strategy:
1. Try the user's Spark Connect server (sc://jupyter-{sanitized-username}.jupyterhub-{env}:15002, or SPARK_CONNECT_URL_TEMPLATE if set)
2. Fall back to shared Spark cluster (spark://sharedsparkclustermaster.prod:7077)
Usage in endpoints:
@app.get("/databases")
def get_databases(spark: Annotated[SparkSession, Depends(get_spark_session)]):
# Use spark here
databases = spark.sql("SHOW DATABASES").collect()
return {"databases": [db.databaseName for db in databases]}
# spark.stop() is automatically called after return
Args:
request: FastAPI request object (used to extract authenticated user)
settings: BERDL settings from environment variables
Yields:
SparkSession configured for the user (either via Connect or direct cluster)
Raises:
Exception: If user is not authenticated or both connection methods fail
"""
spark = None
try:
# Get authenticated user from request
username = get_user_from_request(request)
logger.info(f"Creating Spark session for user: {username}")
# Read user's MinIO credentials from their home directory
try:
minio_access_key, minio_secret_key = read_user_minio_credentials(username)
logger.debug(f"Loaded MinIO credentials for user {username}")
except FileNotFoundError as e:
logger.error(f"MinIO credentials file not found for {username}: {e}")
raise Exception(
f"Cannot create Spark session: MinIO credentials file not found for user {username}. "
f"Ensure .berdl_minio_credentials exists in user's home directory at /home/{username}/.berdl_minio_credentials"
) from e
except Exception as e:
logger.error(
f"Failed to load MinIO credentials for {username}: {type(e).__name__}: {e}",
exc_info=True,
)
raise Exception(
f"Cannot create Spark session: Error reading MinIO credentials for user {username}: {type(e).__name__}: {e}"
) from e
# Build base user-specific settings
base_user_settings = {
"USER": username,
"MINIO_ACCESS_KEY": minio_access_key,
"MINIO_SECRET_KEY": minio_secret_key,
"MINIO_ENDPOINT_URL": settings.MINIO_ENDPOINT_URL,
"MINIO_SECURE": settings.MINIO_SECURE,
"SPARK_HOME": settings.SPARK_HOME,
"SPARK_MASTER_URL": settings.SPARK_MASTER_URL,
"BERDL_HIVE_METASTORE_URI": settings.BERDL_HIVE_METASTORE_URI,
"SPARK_WORKER_COUNT": settings.SPARK_WORKER_COUNT,
"SPARK_WORKER_CORES": settings.SPARK_WORKER_CORES,
"SPARK_WORKER_MEMORY": settings.SPARK_WORKER_MEMORY,
"SPARK_MASTER_CORES": settings.SPARK_MASTER_CORES,
"SPARK_MASTER_MEMORY": settings.SPARK_MASTER_MEMORY,
"GOVERNANCE_API_URL": settings.GOVERNANCE_API_URL,
"BERDL_POD_IP": settings.BERDL_POD_IP,
}
# Try Spark Connect first with TCP pre-flight check
spark_connect_url = construct_user_spark_connect_url(username)
logger.info(f"Checking Spark Connect availability: {spark_connect_url}")
# Quick TCP check to see if Spark Connect port is reachable
if is_spark_connect_reachable(spark_connect_url, timeout=1.0):
logger.info(
f"Spark Connect port reachable, attempting connection: {spark_connect_url}"
)
user_settings = BERDLSettings(
SPARK_CONNECT_URL=AnyUrl(spark_connect_url),
**base_user_settings,
)
spark = _get_spark_session(
app_name=f"datalake_mcp_server_{username}",
settings=user_settings,
use_spark_connect=True,
)
logger.info(f"✅ Connected via Spark Connect for user {username}")
else:
logger.info("Spark Connect port unreachable, using shared Spark cluster")
# Use shared cluster master URL
shared_master_url = os.getenv(
"SHARED_SPARK_MASTER_URL",
"spark://sharedsparkclustermaster.prod:7077",
)
# Create fallback settings with updated SPARK_MASTER_URL
fallback_settings_dict = base_user_settings.copy()
fallback_settings_dict["SPARK_MASTER_URL"] = AnyUrl(shared_master_url)
# Use a dummy connect URL to satisfy Pydantic validation
fallback_settings_dict["SPARK_CONNECT_URL"] = AnyUrl("sc://localhost:15002")
# Ensure BERDL_POD_IP is set for legacy mode
if not fallback_settings_dict.get("BERDL_POD_IP"):
fallback_settings_dict["BERDL_POD_IP"] = (
"0.0.0.0" # Let Spark auto-detect
)
fallback_settings = BERDLSettings(**fallback_settings_dict)
# Note: SPARK_REMOTE env var handling is done within _get_spark_session
# when use_spark_connect=False to ensure it's cleared at the right time
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
spark = _get_spark_session(
app_name=f"datalake_mcp_server_{username}_{timestamp}",
settings=fallback_settings,
use_spark_connect=False,
)
logger.info(
f"✅ Connected via shared Spark cluster for user {username} at {shared_master_url}"
)
# Yield the spark session to the endpoint
logger.debug("Spark session created, yielding to endpoint")
yield spark
finally:
# Always stop the session, even if an exception occurred
if spark is not None:
try:
logger.info("Stopping Spark session (cleanup)")
spark.stop()
logger.debug("Spark session stopped successfully")
except Exception as e:
logger.error(f"Error stopping Spark session: {e}", exc_info=True)