// MCP Terminal Server

// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 package firebase import ( "context" "errors" "testing" "firebase.google.com/go/v4/auth" "github.com/firebase/genkit/go/genkit" ) type mockAuthClient struct { verifyIDTokenFunc func(context.Context, string) (*auth.Token, error) } func (m *mockAuthClient) VerifyIDToken(ctx context.Context, token string) (*auth.Token, error) { return m.verifyIDTokenFunc(ctx, token) } func TestProvideAuthContext(t *testing.T) { t.Parallel() ctx := context.Background() tests := []struct { name string authHeader string required bool mockToken *auth.Token mockError error expectedUID string expectedError string }{ { name: "Valid token", authHeader: "Bearer validtoken", required: true, mockToken: &auth.Token{ UID: "user123", Firebase: auth.FirebaseInfo{ SignInProvider: "custom", }, }, mockError: nil, expectedUID: "user123", expectedError: "", }, { name: "Missing header when required", authHeader: "", required: true, expectedUID: "", expectedError: "authorization header is required but not provided", }, { name: "Missing header when not required", authHeader: "", required: false, expectedUID: "", expectedError: "", }, { name: "Invalid header format", authHeader: "InvalidBearer token", required: true, expectedUID: "", expectedError: "invalid authorization header format", }, { name: "Token verification error", authHeader: "Bearer invalidtoken", required: true, mockToken: nil, mockError: errors.New("invalid token"), expectedUID: "", expectedError: "error verifying ID token: invalid token", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockClient := &mockAuthClient{ verifyIDTokenFunc: func(ctx context.Context, token string) (*auth.Token, error) { if token == "validtoken" { return tt.mockToken, tt.mockError } return nil, tt.mockError }, } auth := &firebaseAuth{ client: mockClient, required: tt.required, } newCtx, err := auth.ProvideAuthContext(ctx, tt.authHeader) if tt.expectedError != "" { if err == nil || 
err.Error() != tt.expectedError { t.Errorf("Expected error %q, got %v", tt.expectedError, err) } } else if err != nil { t.Errorf("Unexpected error: %v", err) } if tt.expectedUID != "" { authContext := auth.FromContext(newCtx) if authContext == nil { t.Errorf("Expected non-nil auth context") } else { uid, ok := authContext["uid"].(string) if !ok { t.Errorf("Expected 'uid' to be a string, got %T", authContext["uid"]) } else if uid != tt.expectedUID { t.Errorf("Expected UID %q, got %q", tt.expectedUID, uid) } } } else if auth.FromContext(newCtx) != nil && tt.authHeader != "" { t.Errorf("Expected nil auth context, but got non-nil") } }) } } func TestCheckAuthPolicy(t *testing.T) { t.Parallel() tests := []struct { name string authContext genkit.AuthContext input any required bool policy func(genkit.AuthContext, any) error expectedError string }{ { name: "Valid auth context and policy", authContext: map[string]any{"uid": "user123"}, input: "test input", required: true, policy: func(authContext genkit.AuthContext, in any) error { return nil }, expectedError: "", }, { name: "Policy error", authContext: map[string]any{"uid": "user123"}, input: "test input", required: true, policy: func(authContext genkit.AuthContext, in any) error { return errors.New("policy error") }, expectedError: "policy error", }, { name: "Missing auth context when required", authContext: nil, input: "test input", required: true, policy: nil, expectedError: "auth is required", }, { name: "Missing auth context when not required", authContext: nil, input: "test input", required: false, policy: nil, expectedError: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { auth := &firebaseAuth{ required: tt.required, policy: tt.policy, } ctx := context.Background() if tt.authContext != nil { ctx = auth.NewContext(ctx, tt.authContext) } err := auth.CheckAuthPolicy(ctx, tt.input) if tt.expectedError != "" { if err == nil || err.Error() != tt.expectedError { t.Errorf("Expected error %q, got %v", 
tt.expectedError, err) } } else if err != nil { t.Errorf("Unexpected error: %v", err) } }) } }