// imagen.ts (6.89 kB)
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {
ActionMetadata,
MediaPart,
MessageData,
modelActionMetadata,
z,
} from 'genkit';
import {
getBasicUsageStats,
modelRef,
type GenerateRequest,
type ModelAction,
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { model as pluginModel } from 'genkit/plugin';
import { imagenPredict } from './client.js';
import type {
ClientOptions,
GoogleAIPluginOptions,
ImagenParameters,
ImagenPredictRequest,
ImagenPrediction,
Model,
} from './types.js';
import {
calculateApiKey,
checkApiKey,
checkModelName,
extractImagenImage,
extractText,
extractVersion,
modelName,
} from './utils.js';
/**
* Zod schema for the per-request configuration accepted by Imagen models.
*
* Every declared field is optional. `numberOfImages` is translated into the
* `sampleCount` request parameter (defaulting to 1) when the request is built,
* and `apiKey` is used to choose the API key for the call rather than being
* forwarded as a request parameter. Keys that are not declared here are kept
* by `.passthrough()` instead of being stripped.
*
* See https://ai.google.dev/gemini-api/docs/image-generation#imagen-model
*/
export const ImagenConfigSchema = z
.object({
// Per-request credential override.
apiKey: z
.string()
.describe('Override the API key provided at plugin initialization.')
.optional(),
// NOTE(review): the 1-4 range is only stated in the description text; the
// schema itself accepts any number.
numberOfImages: z
.number()
.describe(
'The number of images to generate, from 1 to 4 (inclusive). The default is 1.'
)
.optional(),
aspectRatio: z
.enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
.describe('Desired aspect ratio of the output image.')
.optional(),
personGeneration: z
.enum(['dont_allow', 'allow_adult', 'allow_all'])
.describe(
'Control if/how images of people will be generated by the model.'
)
.optional(),
})
// Keep unrecognized keys so callers can supply options not declared above.
.passthrough();
/** Type of the `ImagenConfigSchema` schema object itself (not of the data it validates). */
export type ImagenConfigSchemaType = typeof ImagenConfigSchema;
/** Shape of a config object as inferred from `ImagenConfigSchema`. */
export type ImagenConfig = z.infer<ImagenConfigSchemaType>;
// Config schema type used by the model references in this file. Only one
// Imagen config schema exists at present, so this is a plain alias.
type ConfigSchemaType = ImagenConfigSchemaType;
/**
 * Builds a model reference for an Imagen model under the `googleai/` prefix.
 *
 * @param name Bare model name; `googleai/` is prepended to it.
 * @param info Capability metadata. When omitted, a profile is used that sets
 *   `media: true` and `output: ['media']` and turns off multiturn, tools,
 *   tool choice and system role.
 * @param configSchema Config schema attached to the reference; defaults to
 *   `ImagenConfigSchema`.
 * @returns The reference produced by `modelRef`.
 */
function commonRef(
  name: string,
  info?: ModelInfo,
  configSchema: ConfigSchemaType = ImagenConfigSchema
): ModelReference<ConfigSchemaType> {
  // Capabilities assumed when the caller does not describe the model.
  const fallbackInfo: ModelInfo = {
    supports: {
      media: true,
      multiturn: false,
      tools: false,
      toolChoice: false,
      systemRole: false,
      output: ['media'],
    },
  };

  return modelRef({
    name: `googleai/${name}`,
    configSchema,
    info: info ?? fallbackInfo,
  });
}
// Fallback reference whose `info` is copied for model names that are not in
// KNOWN_MODELS. Unlike the commonRef defaults, it enables multiturn, tools and
// systemRole so that unknown future models are not restricted up front.
// NOTE(review): `toolChoice` is not set here, whereas the commonRef default
// sets it to false - confirm the omission is intended.
const GENERIC_MODEL = commonRef('imagen', {
supports: {
media: true,
multiturn: true,
tools: true,
systemRole: true,
output: ['media'],
},
});
// Imagen model versions with a predefined reference, keyed by bare model name
// (commonRef adds the `googleai/` prefix). Each entry uses the commonRef
// default capabilities. `as const` keeps the keys as literal types so they can
// be used to derive the `KnownModels` union.
const KNOWN_MODELS = {
'imagen-3.0-generate-002': commonRef('imagen-3.0-generate-002'),
'imagen-4.0-generate-preview-06-06': commonRef(
'imagen-4.0-generate-preview-06-06'
),
'imagen-4.0-ultra-generate-preview-06-06': commonRef(
'imagen-4.0-ultra-generate-preview-06-06'
),
} as const;
/** Names of the predefined Imagen models (useful for editor autocomplete). */
export type KnownModels = keyof typeof KNOWN_MODELS;

/** Any string starting with `imagen-`; used for conditional types in index.ts model(). */
export type ImagenModelName = `imagen-${string}`;

/**
 * Type guard telling whether a model name belongs to the Imagen family.
 *
 * @param value Candidate model name; may be omitted.
 * @returns `true` only when `value` is present and begins with `imagen-`.
 */
export function isImagenModelName(value?: string): value is ImagenModelName {
  return value?.startsWith('imagen-') === true;
}
/**
 * Resolves an Imagen model version to a model reference carrying `config`.
 *
 * @param version Model version string; passed through `checkModelName` first.
 * @param config Config to attach to the returned reference (defaults to `{}`).
 * @returns The predefined reference from `KNOWN_MODELS` with `config` applied
 *   when the name is listed there; otherwise a new `googleai/`-prefixed
 *   reference that uses a copy of `GENERIC_MODEL.info`.
 */
export function model(
  version: string,
  config: ImagenConfig = {}
): ModelReference<ConfigSchemaType> {
  const name = checkModelName(version);
  const known = KNOWN_MODELS[name];

  return known
    ? known.withConfig(config)
    : modelRef({
        name: `googleai/${name}`,
        config,
        configSchema: ImagenConfigSchema,
        info: {
          ...GENERIC_MODEL.info,
        },
      });
}
/**
 * Converts a list of API model descriptors into action metadata for the
 * Imagen models among them.
 *
 * A model is included when its `supportedGenerationMethods` contains
 * `'predict'`, its name (via `modelName`) is an Imagen model name, and its
 * description, if any, does not contain the text `deprecated`.
 *
 * @param models Model descriptors to inspect.
 * @returns One metadata entry per included model, in input order.
 */
export function listActions(models: Model[]): ActionMetadata[] {
  const actions: ActionMetadata[] = [];

  for (const candidate of models) {
    const isImagenPredictModel =
      candidate.supportedGenerationMethods.includes('predict') &&
      isImagenModelName(modelName(candidate.name));
    if (!isImagenPredictModel) {
      continue;
    }

    // Skip models whose description marks them as deprecated.
    if (candidate.description && candidate.description.includes('deprecated')) {
      continue;
    }

    const ref = model(candidate.name);
    actions.push(
      modelActionMetadata({
        name: ref.name,
        info: ref.info,
        configSchema: ref.configSchema,
      })
    );
  }

  return actions;
}
/**
 * Defines a model action for every entry in `KNOWN_MODELS`.
 *
 * @param options Plugin options handed to `defineModel` for each model.
 * @returns The defined model actions, in `KNOWN_MODELS` key order.
 */
export function listKnownModels(options?: GoogleAIPluginOptions) {
  const actions: ModelAction[] = [];
  for (const knownName of Object.keys(KNOWN_MODELS)) {
    actions.push(defineModel(knownName, options));
  }
  return actions;
}
/**
 * Creates the model action that serves a single Imagen model.
 *
 * `checkApiKey` is called with the plugin-level key right away, before the
 * action is created. Each incoming request is then turned into one Imagen
 * predict call: the prompt text and image are extracted from the request, the
 * parameters come from its config, and the API key is chosen by
 * `calculateApiKey` from the plugin-level key and the request's
 * `config.apiKey`.
 *
 * @param name Model name/version, resolved through `model()`.
 * @param pluginOptions Plugin options supplying `apiKey`, `apiVersion` and
 *   `baseUrl`.
 * @returns The model action built by the plugin `model` helper.
 * @throws Error Raised by the request handler when the response contains no
 *   predictions.
 */
export function defineModel(
  name: string,
  pluginOptions?: GoogleAIPluginOptions
): ModelAction {
  checkApiKey(pluginOptions?.apiKey);

  const ref = model(name);
  const clientOptions: ClientOptions = {
    apiVersion: pluginOptions?.apiVersion,
    baseUrl: pluginOptions?.baseUrl,
  };

  return pluginModel(
    {
      name: ref.name,
      ...ref.info,
      configSchema: ref.configSchema,
    },
    async (request, { abortSignal }) => {
      // A request maps to exactly one predict instance.
      const instance = {
        prompt: extractText(request),
        image: extractImagenImage(request),
      };
      const predictRequest: ImagenPredictRequest = {
        instances: [instance],
        parameters: toImagenParameters(request),
      };

      const apiKey = calculateApiKey(
        pluginOptions?.apiKey,
        request.config?.apiKey
      );

      const response = await imagenPredict(
        apiKey,
        extractVersion(ref),
        predictRequest,
        // Forward the abort signal so the call can be cancelled.
        { ...clientOptions, signal: abortSignal }
      );

      const predictions = response.predictions;
      if (!predictions || predictions.length === 0) {
        throw new Error(
          'Model returned no predictions. Possibly due to content filters.'
        );
      }

      const message: MessageData = {
        role: 'model',
        content: predictions.map(fromImagenPrediction),
      };

      return {
        finishReason: 'stop',
        message,
        usage: getBasicUsageStats(request.messages, message),
        // Expose the raw API response to callers.
        custom: response,
      };
    }
  );
}
/**
 * Wraps one Imagen prediction in a media part.
 *
 * @param p Prediction providing `bytesBase64Encoded` and `mimeType`.
 * @returns A media part whose `url` is a `data:` URL built from those two
 *   fields and whose `contentType` is the prediction's `mimeType`.
 */
function fromImagenPrediction(p: ImagenPrediction): MediaPart {
  const { bytesBase64Encoded, mimeType } = p;
  const dataUrl = `data:${mimeType};base64,${bytesBase64Encoded}`;

  return { media: { url: dataUrl, contentType: mimeType } };
}
/**
 * Builds the `parameters` object of an Imagen predict request from the
 * request's config.
 *
 * `sampleCount` is taken from `config.numberOfImages` (default 1) and every
 * other config key, including pass-through keys, is copied alongside it.
 * Entries with no value (`undefined`, `null`, empty string or `NaN`) are
 * omitted. Meaningful falsy values such as `false` or `0` are kept, so options
 * like `addWatermark: false` or `seed: 0` reach the API instead of being
 * silently discarded.
 *
 * @param request Generate request whose `config` supplies the parameters.
 * @returns The parameters to send; `apiKey` is never included.
 */
function toImagenParameters(
  request: GenerateRequest<typeof ImagenConfigSchema>
): ImagenParameters {
  const out = {
    sampleCount: request.config?.numberOfImages ?? 1,
    ...request?.config,
  };

  for (const k in out) {
    const value = out[k];
    // Only strip entries that carry no value; `false` and `0` are real settings.
    const isUnset =
      value === undefined ||
      value === null ||
      value === '' ||
      Number.isNaN(value);
    if (isUnset) delete out[k];
  }

  // This is not part of the request parameters sent to the endpoint.
  // It is pulled out and used separately to select the API key.
  delete out.apiKey;

  return out;
}
// Bundle of otherwise module-private helpers and constants. As the name
// indicates, it is exported so tests can reach them; presumably it is not
// meant for use by regular callers.
export const TEST_ONLY = {
toImagenParameters,
fromImagenPrediction,
GENERIC_MODEL,
KNOWN_MODELS,
};