/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { GENKIT_CLIENT_HEADER } from 'genkit';
import { GoogleAuth } from 'google-auth-library';
import { PluginOptions } from './common/types.js';
/**
 * Builds the URL of the Vertex AI `:predict` REST method for a model under
 * the `publishers/google` namespace.
 *
 * `location` appears twice in the result: as the prefix of the API hostname
 * and as the `locations/...` segment of the resource path.
 *
 * @param options - Project ID, location and model name to place in the URL.
 * @returns The fully assembled `https://` URL as a string.
 */
function endpoint(options: {
  projectId: string;
  location: string;
  model: string;
}) {
  const { projectId, location, model } = options;
  const host = `${location}-aiplatform.googleapis.com`;
  const modelResource =
    `projects/${projectId}/locations/${location}/` +
    `publishers/google/models/${model}`;
  return `https://${host}/v1/${modelResource}:predict`;
}
/**
 * Shape that the parsed JSON body of a successful `:predict` call is treated
 * as. `predictModel` casts the response to this type; no runtime validation
 * is performed in this file.
 *
 * @typeParam R - Type of a single prediction entry.
 */
interface PredictionResponse<R> {
/** Prediction entries, each of the caller-chosen type `R`. */
predictions: R[];
}
/**
 * Signature of the async function returned by `predictModel`: it takes an
 * array of instances plus a parameters value and resolves with the
 * prediction response.
 *
 * @typeParam I - Type of each element of `instances`.
 * @typeParam R - Type of each element of the response's `predictions` array.
 * @typeParam P - Type of the `parameters` value.
 */
export type PredictClient<I = unknown, R = unknown, P = unknown> = (
instances: I[],
parameters: P
) => Promise<PredictionResponse<R>>;
/**
 * Creates a client bound to the `:predict` endpoint of one Vertex AI model.
 *
 * Each call of the returned function obtains an access token from `auth`,
 * POSTs `{ instances, parameters }` as JSON to the endpoint, and resolves
 * with the parsed JSON body. The body is cast to `PredictionResponse<R>`
 * without runtime validation.
 *
 * @param auth - Supplies the bearer token for the `Authorization` header.
 * @param options - Plugin options; only `location` and `projectId` are read.
 * @param model - Model name placed in the endpoint URL.
 * @returns An async function performing one prediction request per call.
 * @throws Error - Raised by the returned function (as a rejected promise)
 *   when `projectId` is missing, when no access token can be obtained, or
 *   when the service answers with a non-OK HTTP status.
 */
export function predictModel<I = unknown, R = unknown, P = unknown>(
  auth: GoogleAuth,
  { location, projectId }: PluginOptions,
  model: string
): PredictClient<I, R, P> {
  return async (
    instances: I[],
    parameters: P
  ): Promise<PredictionResponse<R>> => {
    // Guard instead of a non-null assertion: without it a missing project ID
    // would be interpolated into the URL as the literal text "undefined".
    if (!projectId) {
      throw new Error(
        'Vertex AI predict requires a projectId, but none was provided in the plugin options.'
      );
    }

    // A missing token would otherwise be sent as the header "Bearer null".
    const accessToken = await auth.getAccessToken();
    if (!accessToken) {
      throw new Error(
        'Vertex AI predict could not obtain an access token from GoogleAuth.'
      );
    }

    // Dynamic import: node-fetch is resolved when a request is made, not
    // when this module is loaded.
    const fetch = (await import('node-fetch')).default;

    const response = await fetch(endpoint({ projectId, location, model }), {
      method: 'POST',
      body: JSON.stringify({ instances, parameters }),
      headers: {
        Authorization: `Bearer ${accessToken}`,
        'Content-Type': 'application/json',
        'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
      },
    });

    if (!response.ok) {
      // Include the response body so the service's own error is visible.
      throw new Error(
        `Error from Vertex AI predict: HTTP ${
          response.status
        }: ${await response.text()}`
      );
    }

    return (await response.json()) as PredictionResponse<R>;
  };
}