package scanner
import (
"context"
"testing"
"time"
malysisv1pb "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/messages/malysis/v1"
packagev1 "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/messages/package/v1"
malysisv1 "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/services/malysis/v1"
"github.com/safedep/dry/adapters"
"github.com/safedep/vet/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
// mockMalwareAnalysisServiceClient is a testify-backed stand-in for the malysis
// malware analysis service client. Tests program its responses with On/Return
// and verify the recorded calls with AssertExpectations via the embedded mock.Mock.
type mockMalwareAnalysisServiceClient struct {
	mock.Mock
}
// QueryPackageAnalysis records the call and returns the response and error
// configured via On("QueryPackageAnalysis", ...).Return(...).
//
// The response is extracted with a comma-ok assertion so that a test may
// configure Return(nil, err) with an untyped nil without the mock panicking;
// in that case a nil response is returned alongside the configured error.
func (m *mockMalwareAnalysisServiceClient) QueryPackageAnalysis(
	ctx context.Context,
	in *malysisv1.QueryPackageAnalysisRequest,
	opts ...grpc.CallOption,
) (*malysisv1.QueryPackageAnalysisResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.QueryPackageAnalysisResponse)
	return resp, args.Error(1)
}
// AnalyzePackage records the call and returns the response and error
// configured via On("AnalyzePackage", ...).Return(...).
//
// A comma-ok assertion tolerates an untyped nil response in Return(nil, err),
// yielding a nil response instead of a type-assertion panic.
func (m *mockMalwareAnalysisServiceClient) AnalyzePackage(
	ctx context.Context,
	in *malysisv1.AnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.AnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.AnalyzePackageResponse)
	return resp, args.Error(1)
}
// GetAnalysisReport records the call and returns the response and error
// configured via On("GetAnalysisReport", ...).Return(...).
//
// A comma-ok assertion tolerates an untyped nil response in Return(nil, err),
// yielding a nil response instead of a type-assertion panic.
func (m *mockMalwareAnalysisServiceClient) GetAnalysisReport(
	ctx context.Context,
	in *malysisv1.GetAnalysisReportRequest,
	opts ...grpc.CallOption,
) (*malysisv1.GetAnalysisReportResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.GetAnalysisReportResponse)
	return resp, args.Error(1)
}
// InternalAnalyzePackage records the call and returns the response and error
// configured via On("InternalAnalyzePackage", ...).Return(...).
//
// A comma-ok assertion tolerates an untyped nil response in Return(nil, err),
// yielding a nil response instead of a type-assertion panic.
func (m *mockMalwareAnalysisServiceClient) InternalAnalyzePackage(
	ctx context.Context,
	in *malysisv1.InternalAnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.InternalAnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.InternalAnalyzePackageResponse)
	return resp, args.Error(1)
}
// ListPackageAnalysisRecords records the call and returns the response and
// error configured via On("ListPackageAnalysisRecords", ...).Return(...).
//
// A comma-ok assertion tolerates an untyped nil response in Return(nil, err),
// yielding a nil response instead of a type-assertion panic.
func (m *mockMalwareAnalysisServiceClient) ListPackageAnalysisRecords(
	ctx context.Context,
	in *malysisv1.ListPackageAnalysisRecordsRequest,
	opts ...grpc.CallOption,
) (*malysisv1.ListPackageAnalysisRecordsResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.ListPackageAnalysisRecordsResponse)
	return resp, args.Error(1)
}
// InternalAgenticAnalyzePackage records the call and returns the response and
// error configured via On("InternalAgenticAnalyzePackage", ...).Return(...).
//
// A comma-ok assertion tolerates an untyped nil response in Return(nil, err),
// yielding a nil response instead of a type-assertion panic.
func (m *mockMalwareAnalysisServiceClient) InternalAgenticAnalyzePackage(
	ctx context.Context,
	in *malysisv1.InternalAgenticAnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.InternalAgenticAnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.InternalAgenticAnalyzePackageResponse)
	return resp, args.Error(1)
}
// TestMalysisMalwareAnalysisQueryEnricherEnrich checks that Enrich issues a
// QueryPackageAnalysis request built from the package's ecosystem, name and
// version. On success, the response's analysis ID, report and verification
// record must be copied onto the package's malware analysis result. A gRPC
// error must be surfaced by Enrich.
func TestMalysisMalwareAnalysisQueryEnricherEnrich(t *testing.T) {
	testCases := []struct {
		name          string
		pkg           *models.Package
		mockResponse  *malysisv1.QueryPackageAnalysisResponse
		mockError     error
		expectedError bool
	}{
		{
			name: "successful enrichment for npm package",
			pkg: &models.Package{
				PackageDetails: models.NewPackageDetail("npm", "test-package", "1.0.0"),
				Manifest: &models.PackageManifest{
					Ecosystem: models.EcosystemNpm,
				},
			},
			mockResponse: &malysisv1.QueryPackageAnalysisResponse{
				AnalysisId:         "test-analysis-id",
				Report:             &malysisv1pb.Report{},
				VerificationRecord: &malysisv1pb.VerificationRecord{},
			},
			mockError:     nil,
			expectedError: false,
		},
		{
			name: "gRPC call returns error",
			pkg: &models.Package{
				PackageDetails: models.NewPackageDetail("maven", "test-package", "1.0.0"),
				Manifest: &models.PackageManifest{
					Ecosystem: models.EcosystemMaven,
				},
			},
			mockResponse:  nil,
			mockError:     assert.AnError,
			expectedError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mockClient := &mockMalwareAnalysisServiceClient{}
			mockGHA := &adapters.GithubClient{}

			// The mock only answers when Enrich sends exactly this request, so
			// the expectation also pins how the request is built.
			req := &malysisv1.QueryPackageAnalysisRequest{
				Target: &malysisv1pb.PackageAnalysisTarget{
					PackageVersion: &packagev1.PackageVersion{
						Package: &packagev1.Package{
							Ecosystem: tc.pkg.GetControlTowerSpecEcosystem(),
							Name:      tc.pkg.GetName(),
						},
						Version: tc.pkg.GetVersion(),
					},
				},
			}
			mockClient.On("QueryPackageAnalysis", mock.Anything, req, mock.Anything).
				Return(tc.mockResponse, tc.mockError)

			enricher := &malysisMalwareAnalysisQueryEnricher{
				client: mockClient,
				config: MalysisMalwareEnricherConfig{
					GrpcOperationTimeout: 2 * time.Second,
				},
				gha: mockGHA,
			}

			err := enricher.Enrich(tc.pkg, func(_ *models.Package) error {
				return nil
			})

			if tc.expectedError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				// assert.NotNil does not stop the test; guard the field
				// accesses so a missing result is reported as a failure
				// instead of a nil-pointer panic.
				result := tc.pkg.GetMalwareAnalysisResult()
				if assert.NotNil(t, result) {
					assert.Equal(t, tc.mockResponse.GetAnalysisId(), result.AnalysisId)
					assert.Equal(t, tc.mockResponse.GetReport(), result.Report)
					assert.Equal(t, tc.mockResponse.GetVerificationRecord(), result.VerificationRecord)
				}
			}

			mockClient.AssertExpectations(t)
		})
	}
}
// TestNewMalysisMalwareAnalysisQueryEnricher checks that the constructor
// rejects a nil gRPC connection or a nil GitHub client. With valid inputs it
// must return an enricher holding the supplied connection, GitHub client and
// config.
func TestNewMalysisMalwareAnalysisQueryEnricher(t *testing.T) {
	testCases := []struct {
		name          string
		cc            *grpc.ClientConn
		gha           *adapters.GithubClient
		config        MalysisMalwareEnricherConfig
		expectedError bool
	}{
		{
			name:          "nil client connection",
			cc:            nil,
			gha:           &adapters.GithubClient{},
			config:        MalysisMalwareEnricherConfig{},
			expectedError: true,
		},
		{
			name:          "nil github client",
			cc:            &grpc.ClientConn{},
			gha:           nil,
			config:        MalysisMalwareEnricherConfig{},
			expectedError: true,
		},
		{
			name:          "valid inputs",
			cc:            &grpc.ClientConn{},
			gha:           &adapters.GithubClient{},
			config:        MalysisMalwareEnricherConfig{},
			expectedError: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			enricher, err := NewMalysisMalwareAnalysisQueryEnricher(tc.cc, tc.gha, tc.config)
			if tc.expectedError {
				assert.Error(t, err)
				assert.Nil(t, enricher)
				return
			}

			assert.NoError(t, err)
			// Stop here if construction failed: the field checks below would
			// otherwise dereference a nil enricher and panic.
			if !assert.NotNil(t, enricher) {
				return
			}
			assert.Equal(t, tc.cc, enricher.cc)
			assert.Equal(t, tc.gha, enricher.gha)
			assert.Equal(t, tc.config, enricher.config)
		})
	}
}