package github
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/github/github-mcp-server/pkg/lockdown"
"github.com/github/github-mcp-server/pkg/raw"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v79/github"
"github.com/shurcooL/githubv4"
"github.com/stretchr/testify/assert"
)
// stubDeps is a configurable test double for ToolDependencies.
// Each accessor method on stubDeps either delegates to the matching function
// field (when set) or hands back the matching value field unchanged, which
// makes it convenient for exercising error paths and closure-based client
// creation.
type stubDeps struct {
	// clientFn, when non-nil, is invoked by GetClient.
	clientFn func(context.Context) (*github.Client, error)
	// gqlClientFn, when non-nil, is invoked by GetGQLClient.
	gqlClientFn func(context.Context) (*githubv4.Client, error)
	// rawClientFn, when non-nil, is invoked by GetRawClient.
	rawClientFn func(context.Context) (*raw.Client, error)

	// repoAccessCache is returned verbatim by GetRepoAccessCache.
	repoAccessCache *lockdown.RepoAccessCache
	// t is returned verbatim by GetT.
	t translations.TranslationHelperFunc
	// flags is returned verbatim by GetFlags.
	flags FeatureFlags
	// contentWindowSize is returned verbatim by GetContentWindowSize.
	contentWindowSize int
}
// GetClient delegates to clientFn when one is configured.
// With no clientFn set it returns a nil client and a nil error.
func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) {
	if s.clientFn == nil {
		return nil, nil
	}
	return s.clientFn(ctx)
}
// GetGQLClient delegates to gqlClientFn when one is configured.
// With no gqlClientFn set it returns a nil client and a nil error.
func (s stubDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) {
	if s.gqlClientFn == nil {
		return nil, nil
	}
	return s.gqlClientFn(ctx)
}
// GetRawClient delegates to rawClientFn when one is configured.
// With no rawClientFn set it returns a nil client and a nil error.
func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) {
	if s.rawClientFn == nil {
		return nil, nil
	}
	return s.rawClientFn(ctx)
}
// GetRepoAccessCache returns the repoAccessCache field unchanged (nil when unset).
func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache {
	return s.repoAccessCache
}
// GetT returns the translation helper stored in the t field (nil when unset).
func (s stubDeps) GetT() translations.TranslationHelperFunc {
	return s.t
}
// GetFlags returns the flags field unchanged (the zero FeatureFlags when unset).
func (s stubDeps) GetFlags() FeatureFlags {
	return s.flags
}
// GetContentWindowSize returns the contentWindowSize field unchanged (0 when unset).
func (s stubDeps) GetContentWindowSize() int {
	return s.contentWindowSize
}
// Helper functions to create stub client functions for error testing
// stubClientFnFromHTTP builds a client factory suitable for stubDeps.clientFn.
// Every invocation of the returned function constructs a new github.Client
// on top of httpClient and never reports an error; the context is ignored.
func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) {
	factory := func(context.Context) (*github.Client, error) {
		client := github.NewClient(httpClient)
		return client, nil
	}
	return factory
}
// stubClientFnErr builds a client factory that always fails.
// The returned function ignores its context and yields a nil client together
// with a freshly created error whose message is errMsg.
func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) {
	failing := func(context.Context) (*github.Client, error) {
		// The error is created per call, matching one errors.New per invocation.
		failure := errors.New(errMsg)
		return nil, failure
	}
	return failing
}
// stubGQLClientFnErr builds a GraphQL client factory that always fails.
// The returned function ignores its context and yields a nil client together
// with a freshly created error whose message is errMsg.
func stubGQLClientFnErr(errMsg string) func(context.Context) (*githubv4.Client, error) {
	failing := func(context.Context) (*githubv4.Client, error) {
		// The error is created per call, matching one errors.New per invocation.
		failure := errors.New(errMsg)
		return nil, failure
	}
	return failing
}
// stubRepoAccessCache obtains a RepoAccessCache from lockdown.GetInstance for
// the given GraphQL client, configured with the supplied TTL and a cache name
// derived from the current nanosecond timestamp.
// NOTE(review): the timestamped name presumably keeps tests from sharing a
// cache instance — confirm against lockdown.GetInstance.
func stubRepoAccessCache(client *githubv4.Client, ttl time.Duration) *lockdown.RepoAccessCache {
	stamp := time.Now().UnixNano()
	name := fmt.Sprintf("repo-access-cache-test-%d", stamp)
	return lockdown.GetInstance(
		client,
		lockdown.WithTTL(ttl),
		lockdown.WithCacheName(name),
	)
}
// stubFeatureFlags converts a name->enabled map into a FeatureFlags value.
// Only the "lockdown-mode" key is consulted; a missing key (or a nil map)
// reads as false, and all other FeatureFlags fields keep their zero value.
func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags {
	var flags FeatureFlags
	flags.LockdownMode = enabledFlags["lockdown-mode"]
	return flags
}
// badRequestHandler returns an http.HandlerFunc that answers every request
// with HTTP 400 and a body containing a JSON-encoded github.ErrorResponse
// whose Message is msg.
//
// If the error response cannot be marshalled, the handler replies with
// HTTP 500 and stops; it must not fall through and write a second response.
func badRequestHandler(msg string) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		structuredErrorResponse := github.ErrorResponse{
			Message: msg,
		}
		b, err := json.Marshal(structuredErrorResponse)
		if err != nil {
			http.Error(w, "failed to marshal error response", http.StatusInternalServerError)
			// http.Error does not end the request; return so no further
			// header or body writes are attempted on w.
			return
		}
		http.Error(w, string(b), http.StatusBadRequest)
	}
}
// Test_IsAcceptedError checks isAcceptedError against a direct
// *github.AcceptedError, a wrapped one, an unrelated error, and nil.
func Test_IsAcceptedError(t *testing.T) {
	type testCase struct {
		name string
		err  error
		want bool
	}
	cases := []testCase{
		{name: "github AcceptedError", err: &github.AcceptedError{}, want: true},
		{name: "regular error", err: fmt.Errorf("some other error"), want: false},
		{name: "nil error", err: nil, want: false},
		{name: "wrapped AcceptedError", err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), want: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := isAcceptedError(c.err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_RequiredStringParam exercises RequiredParam[string]: a present,
// non-empty string is returned as-is, while a missing key, an empty string,
// or a non-string value is expected to produce an error.
func Test_RequiredStringParam(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      string
		wantErr   bool
	}
	cases := []testCase{
		{name: "valid string parameter", params: map[string]any{"name": "test-value"}, paramName: "name", want: "test-value", wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "name", want: "", wantErr: true},
		{name: "empty string parameter", params: map[string]any{"name": ""}, paramName: "name", want: "", wantErr: true},
		{name: "wrong type parameter", params: map[string]any{"name": 123}, paramName: "name", want: "", wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := RequiredParam[string](c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_OptionalStringParam exercises OptionalParam[string]: present, missing,
// and empty values are expected to succeed (missing/empty yielding ""), while
// a non-string value is expected to produce an error.
func Test_OptionalStringParam(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      string
		wantErr   bool
	}
	cases := []testCase{
		{name: "valid string parameter", params: map[string]any{"name": "test-value"}, paramName: "name", want: "test-value", wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "name", want: "", wantErr: false},
		{name: "empty string parameter", params: map[string]any{"name": ""}, paramName: "name", want: "", wantErr: false},
		{name: "wrong type parameter", params: map[string]any{"name": 123}, paramName: "name", want: "", wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalParam[string](c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_RequiredInt exercises RequiredInt: a float64 value is expected to come
// back as the matching int, while a missing key or a non-numeric value is
// expected to produce an error.
func Test_RequiredInt(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      int
		wantErr   bool
	}
	cases := []testCase{
		{name: "valid number parameter", params: map[string]any{"count": float64(42)}, paramName: "count", want: 42, wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "count", want: 0, wantErr: true},
		{name: "wrong type parameter", params: map[string]any{"count": "not-a-number"}, paramName: "count", want: 0, wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := RequiredInt(c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_OptionalIntParam exercises OptionalIntParam: a float64 value is
// expected to come back as the matching int, a missing key or explicit zero
// as 0 without error, and a non-numeric value as an error.
func Test_OptionalIntParam(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      int
		wantErr   bool
	}
	cases := []testCase{
		{name: "valid number parameter", params: map[string]any{"count": float64(42)}, paramName: "count", want: 42, wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "count", want: 0, wantErr: false},
		{name: "zero value", params: map[string]any{"count": float64(0)}, paramName: "count", want: 0, wantErr: false},
		{name: "wrong type parameter", params: map[string]any{"count": "not-a-number"}, paramName: "count", want: 0, wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalIntParam(c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_OptionalNumberParamWithDefault exercises OptionalIntParamWithDefault:
// a supplied non-zero value is expected to win, a missing key or an explicit
// zero is expected to yield the default, and a non-numeric value is expected
// to produce an error.
func Test_OptionalNumberParamWithDefault(t *testing.T) {
	type testCase struct {
		name       string
		params     map[string]any
		paramName  string
		defaultVal int
		want       int
		wantErr    bool
	}
	cases := []testCase{
		{name: "valid number parameter", params: map[string]any{"count": float64(42)}, paramName: "count", defaultVal: 10, want: 42, wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "count", defaultVal: 10, want: 10, wantErr: false},
		{name: "zero value", params: map[string]any{"count": float64(0)}, paramName: "count", defaultVal: 10, want: 10, wantErr: false},
		{name: "wrong type parameter", params: map[string]any{"count": "not-a-number"}, paramName: "count", defaultVal: 10, want: 0, wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalIntParamWithDefault(c.params, c.paramName, c.defaultVal)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_OptionalBooleanParam exercises OptionalParam[bool]: true and false are
// expected to round-trip, a missing key is expected to yield false without
// error, and a non-boolean value is expected to produce an error.
func Test_OptionalBooleanParam(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      bool
		wantErr   bool
	}
	cases := []testCase{
		{name: "true value", params: map[string]any{"flag": true}, paramName: "flag", want: true, wantErr: false},
		{name: "false value", params: map[string]any{"flag": false}, paramName: "flag", want: false, wantErr: false},
		{name: "missing parameter", params: map[string]any{}, paramName: "flag", want: false, wantErr: false},
		{name: "wrong type parameter", params: map[string]any{"flag": "not-a-boolean"}, paramName: "flag", want: false, wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalParam[bool](c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// TestOptionalStringArrayParam exercises OptionalStringArrayParam: a missing
// key is expected to yield an empty (non-nil) slice, []any of strings and
// []string are expected to be returned as []string, and a non-slice value or
// a slice containing a non-string element is expected to produce an error.
func TestOptionalStringArrayParam(t *testing.T) {
	type testCase struct {
		name      string
		params    map[string]any
		paramName string
		want      []string
		wantErr   bool
	}
	cases := []testCase{
		{name: "parameter not in request", params: map[string]any{}, paramName: "flag", want: []string{}, wantErr: false},
		{name: "valid any array parameter", params: map[string]any{"flag": []any{"v1", "v2"}}, paramName: "flag", want: []string{"v1", "v2"}, wantErr: false},
		{name: "valid string array parameter", params: map[string]any{"flag": []string{"v1", "v2"}}, paramName: "flag", want: []string{"v1", "v2"}, wantErr: false},
		{name: "wrong type parameter", params: map[string]any{"flag": 1}, paramName: "flag", want: []string{}, wantErr: true},
		{name: "wrong slice type parameter", params: map[string]any{"flag": []any{"foo", 2}}, paramName: "flag", want: []string{}, wantErr: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalStringArrayParam(c.params, c.paramName)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// TestOptionalPaginationParams exercises OptionalPaginationParams: absent
// "page"/"perPage" keys are expected to fall back to Page 1 and PerPage 30,
// supplied float64 values are expected to override them, and a non-numeric
// value for either key is expected to produce an error.
func TestOptionalPaginationParams(t *testing.T) {
	type testCase struct {
		name    string
		params  map[string]any
		want    PaginationParams
		wantErr bool
	}
	cases := []testCase{
		{
			name:    "no pagination parameters, default values",
			params:  map[string]any{},
			want:    PaginationParams{Page: 1, PerPage: 30},
			wantErr: false,
		},
		{
			name:    "page parameter, default perPage",
			params:  map[string]any{"page": float64(2)},
			want:    PaginationParams{Page: 2, PerPage: 30},
			wantErr: false,
		},
		{
			name:    "perPage parameter, default page",
			params:  map[string]any{"perPage": float64(50)},
			want:    PaginationParams{Page: 1, PerPage: 50},
			wantErr: false,
		},
		{
			name:    "page and perPage parameters",
			params:  map[string]any{"page": float64(2), "perPage": float64(50)},
			want:    PaginationParams{Page: 2, PerPage: 50},
			wantErr: false,
		},
		{
			name:    "invalid page parameter",
			params:  map[string]any{"page": "not-a-number"},
			want:    PaginationParams{},
			wantErr: true,
		},
		{
			name:    "invalid perPage parameter",
			params:  map[string]any{"perPage": "not-a-number"},
			want:    PaginationParams{},
			wantErr: true,
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := OptionalPaginationParams(c.params)
			if c.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}