Skip to main content
Glama
lib.rs (15 kB)
use core::str; use std::sync::Arc; use base64::{ Engine, engine::general_purpose, }; use jwt_simple::{ common::VerificationOptions, prelude::*, }; use monostate::MustBe; use serde::{ Deserialize, Serialize, }; use si_events::AuthenticationMethodRole; use si_id::{ AuthTokenId, UserPk, WorkspacePk, }; use si_std::CanonicalFile; use telemetry::prelude::*; use thiserror::Error; use tokio::{ fs, io::AsyncReadExt, task::JoinError, }; #[remain::sorted] #[derive(Error, Debug)] pub enum JwtPublicSigningKeyError { #[error("failed to decode base64 string: {0}")] Base64Decode(#[from] base64::DecodeError), #[error("invalid bearer token")] BearerToken, #[error("error creating jwt from config")] FromConfig, #[error("io error: {0}")] Io(#[from] std::io::Error), #[error("JWT error: {0}")] Jwt(#[from] jwt_simple::Error), #[error("{0}")] TaskJoin(#[from] JoinError), #[error("failed to decode ulid: {0}")] UlidDecode(#[from] si_id::ulid::DecodeError), #[error("Unsupported JWT signing algorithm: {0}")] UnsupportedAlgo(String), #[error("failed to build string from utf8: {0}")] Utf8(#[from] std::str::Utf8Error), #[error("failure to verify token: {0}")] Verify(String), #[error("failure to verify against secondary token: first error: {0}, second error: {1}")] VerifySecondaryFail(String, String), } pub type JwtKeyResult<T> = Result<T, JwtPublicSigningKeyError>; #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct JwtConfig { pub key_file: Option<CanonicalFile>, pub key_base64: Option<String>, pub algo: JwtAlgo, } impl JwtConfig { pub async fn to_pem(self) -> JwtKeyResult<String> { Ok(match (self.key_file.as_ref(), self.key_base64.as_deref()) { (None, Some(key_base64)) => { let buf = general_purpose::STANDARD.decode(key_base64)?; str::from_utf8(&buf)?.to_string() } (Some(key_file), None) => { let mut file = fs::File::open(key_file).await?; let mut buf = String::new(); file.read_to_string(&mut buf).await?; buf } _ => Err(JwtPublicSigningKeyError::FromConfig)?, }) } pub async fn 
into_verify(self) -> JwtKeyResult<Arc<dyn JwtPublicKeyVerify>> { let algo = self.algo; let pem = self.to_pem().await?; Ok(match algo { JwtAlgo::ES256 => { Arc::new(ES256PublicKey::from_pem(&pem)?) as Arc<dyn JwtPublicKeyVerify> } JwtAlgo::RS256 => { Arc::new(RS256PublicKey::from_pem(&pem)?) as Arc<dyn JwtPublicKeyVerify> } }) } } /** Role indicating what permissions the user should have */ #[derive(Deserialize, Serialize, Debug, Copy, Clone, Eq, PartialEq)] #[serde(rename_all = "camelCase")] pub enum SiJwtClaimRole { Web, Automation, } impl SiJwtClaimRole { pub fn is_superset_of(&self, other: Self) -> bool { match (self, other) { (Self::Web, Self::Web | Self::Automation) => true, (Self::Automation, Self::Automation) => true, (Self::Automation, Self::Web) => false, } } } impl From<SiJwtClaimRole> for AuthenticationMethodRole { fn from(role: SiJwtClaimRole) -> Self { match role { SiJwtClaimRole::Web => AuthenticationMethodRole::Web, SiJwtClaimRole::Automation => AuthenticationMethodRole::Automation, } } } #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] #[serde(untagged)] pub enum SiJwtClaims { V2(SiJwtClaimsV2), #[serde(rename_all = "snake_case")] V1(SiJwtClaimsV1), } /** The whole token */ pub type SiJwt = JWTClaims<SiJwtClaims>; #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SiJwtClaimsV2 { pub version: MustBe!("2"), pub user_id: UserPk, pub workspace_id: WorkspacePk, pub role: SiJwtClaimRole, } #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] #[serde(rename_all = "snake_case")] pub struct SiJwtClaimsV1 { pub user_pk: UserPk, pub workspace_pk: WorkspacePk, } impl SiJwtClaims { pub fn token_id(token: &SiJwt) -> JwtKeyResult<Option<AuthTokenId>> { match token.jwt_id { Some(ref jwt_id) => Ok(Some(jwt_id.parse()?)), None => Ok(None), } } pub fn user_id(&self) -> UserPk { match self { Self::V2(SiJwtClaimsV2 { user_id, .. }) => *user_id, Self::V1(SiJwtClaimsV1 { user_pk, .. 
}) => *user_pk, } } pub fn workspace_id(&self) -> WorkspacePk { match self { Self::V2(SiJwtClaimsV2 { workspace_id, .. }) => *workspace_id, Self::V1(SiJwtClaimsV1 { workspace_pk, .. }) => *workspace_pk, } } pub fn role(&self) -> SiJwtClaimRole { match self { Self::V2(SiJwtClaimsV2 { role, .. }) => *role, Self::V1(SiJwtClaimsV1 { .. }) => SiJwtClaimRole::Web, } } pub fn authorized_for(&self, required_role: SiJwtClaimRole) -> bool { self.role().is_superset_of(required_role) } pub fn for_web(user_id: UserPk, workspace_id: WorkspacePk) -> Self { Self::V2(SiJwtClaimsV2 { version: MustBe!("2"), user_id, workspace_id, role: SiJwtClaimRole::Web, }) } pub async fn from_bearer_token( public_key: JwtPublicSigningKeyChain, token: impl AsRef<str>, ) -> JwtKeyResult<SiJwtClaims> { let claims = validate_bearer_token(public_key, token).await?; Ok(claims.custom) } pub async fn from_raw_token( public_key: JwtPublicSigningKeyChain, token: impl Into<String>, ) -> JwtKeyResult<SiJwtClaims> { let claims = validate_raw_token(public_key, token).await?; Ok(claims.custom) } } #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub enum JwtAlgo { #[default] ES256, RS256, } pub trait JwtPublicKeyVerify: std::fmt::Debug + Send + Sync { fn algo(&self) -> JwtAlgo; fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt>; } impl JwtPublicKeyVerify for RS256PublicKey { fn algo(&self) -> JwtAlgo { JwtAlgo::RS256 } fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt> { self.verify_token(token, options) .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) } } impl JwtPublicKeyVerify for ES256PublicKey { fn algo(&self) -> JwtAlgo { JwtAlgo::ES256 } fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt> { self.verify_token(token, options) .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) } } #[derive(Clone, Debug)] pub struct JwtPublicSigningKeyChain { 
primary: Arc<dyn JwtPublicKeyVerify>, secondary: Option<Arc<dyn JwtPublicKeyVerify>>, } impl JwtPublicSigningKeyChain { pub async fn from_config( primary: JwtConfig, secondary: Option<JwtConfig>, ) -> JwtKeyResult<Self> { Ok(Self { primary: primary.into_verify().await?, secondary: match secondary { Some(jwt_cfg) => Some(jwt_cfg.into_verify().await?), None => None, }, }) } /// Attempt to verify that this token was signed by either the primary or /// secondary key(s) pub fn verify_token( &self, token: &str, options: Option<VerificationOptions>, ) -> JwtKeyResult<SiJwt> { match self.primary.verify(token, options.clone()) { Ok(claims) => Ok(claims), Err(err) => match self.secondary.as_ref() { Some(secondary) => match secondary.verify(token, options) { Ok(claims) => Ok(claims), Err(second_err) => Err(JwtPublicSigningKeyError::VerifySecondaryFail( err.to_string(), second_err.to_string(), )), }, None => Err(err), }, } } } pub async fn validate_bearer_token( public_key: JwtPublicSigningKeyChain, bearer_token: impl AsRef<str>, ) -> JwtKeyResult<SiJwt> { let token = bearer_token .as_ref() .strip_prefix("Bearer ") .ok_or(JwtPublicSigningKeyError::BearerToken)? 
.to_string(); validate_raw_token(public_key, token).await } #[instrument(level = "debug", skip_all)] pub async fn validate_raw_token( public_key: JwtPublicSigningKeyChain, token: impl Into<String>, ) -> JwtKeyResult<SiJwt> { let token = token.into(); let claims = tokio::task::spawn_blocking(move || public_key.verify_token(&token, None)).await??; Ok(claims) } #[cfg(test)] mod tests { use super::*; // TODO test these with V2 and V1 fn v1_and_v2_claims() -> impl IntoIterator<Item = SiJwtClaims> { [ SiJwtClaims::V1(SiJwtClaimsV1 { user_pk: UserPk::generate(), workspace_pk: WorkspacePk::generate(), }), SiJwtClaims::V2(SiJwtClaimsV2 { version: MustBe!("2"), user_id: UserPk::generate(), workspace_id: WorkspacePk::generate(), role: SiJwtClaimRole::Web, }), ] } #[tokio::test] async fn validate_with_primary_rs256() { for si_claim in v1_and_v2_claims() { println!("generating key..."); let key_pair = RS256KeyPair::generate(2048).expect("generate key pair"); println!("done"); let pub_key = key_pair.public_key(); let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); let claims = JWTClaims { issued_at: None, expires_at: None, invalid_before: None, issuer: None, subject: None, audiences: None, jwt_id: None, nonce: None, custom: si_claim.clone(), }; let signed = key_pair.sign(claims).expect("sign the key"); let bearer_token = format!("Bearer {signed}"); let primary_cfg = JwtConfig { key_file: None, key_base64: Some(pub_key_base64), algo: JwtAlgo::RS256, }; let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) .await .expect("make key chain"); let claims = validate_bearer_token(key_chain, &bearer_token) .await .expect("should validate"); assert_eq!(si_claim, claims.custom); } } #[tokio::test] async fn validate_with_primary_es256() { for si_claim in v1_and_v2_claims() { println!("generating key..."); let key_pair = ES256KeyPair::generate(); let key_pair_2 = ES256KeyPair::generate(); 
println!("done"); let pub_key = key_pair.public_key(); let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); let claims = JWTClaims { issued_at: None, expires_at: None, invalid_before: None, issuer: None, subject: None, audiences: None, jwt_id: None, nonce: None, custom: si_claim.clone(), }; let signed = key_pair.sign(claims.clone()).expect("sign the key"); let bearer_token = format!("Bearer {signed}"); let primary_cfg = JwtConfig { key_file: None, key_base64: Some(pub_key_base64), algo: JwtAlgo::ES256, }; let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) .await .expect("make key chain"); let claims = validate_bearer_token(key_chain.clone(), &bearer_token) .await .expect("should validate"); assert_eq!(si_claim, claims.custom); // Just confirm it fails with the wrong key let signed_bad = key_pair_2.sign(claims).expect("sign the key"); let bearer_bad = format!("Bearer {signed_bad}"); let result = validate_bearer_token(key_chain, &bearer_bad).await; assert!(result.is_err()); } } #[tokio::test] async fn validate_with_secondary_rs256() { for si_claim in v1_and_v2_claims() { println!("generating keys..."); let key_pair_es256 = ES256KeyPair::generate(); let key_pair_rs256 = RS256KeyPair::generate(2048).expect("generate rs256 key"); println!("done"); let pub_key_es256 = key_pair_es256.public_key(); let pub_key_pem = pub_key_es256.to_pem().expect("get pub key pem"); let pub_key_base64_es256 = general_purpose::STANDARD.encode(pub_key_pem); let pub_key_rs256 = key_pair_rs256.public_key(); let pub_key_pem = pub_key_rs256.to_pem().expect("get pub key pem"); let pub_key_base64_rs256 = general_purpose::STANDARD.encode(pub_key_pem); let claims = JWTClaims { issued_at: None, expires_at: None, invalid_before: None, issuer: None, subject: None, audiences: None, jwt_id: None, nonce: None, custom: si_claim.clone(), }; let signed = key_pair_rs256.sign(claims.clone()).expect("sign the 
key"); let bearer_token = format!("Bearer {signed}"); let primary_cfg = JwtConfig { key_file: None, key_base64: Some(pub_key_base64_es256), algo: JwtAlgo::ES256, }; let secondary_cfg = JwtConfig { key_file: None, key_base64: Some(pub_key_base64_rs256), algo: JwtAlgo::RS256, }; let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, Some(secondary_cfg)) .await .expect("make key chain"); let claims = validate_bearer_token(key_chain.clone(), &bearer_token) .await .expect("should validate"); assert_eq!(si_claim, claims.custom); } } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/systeminit/si'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.