/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { ActionMetadata, modelActionMetadata, z } from 'genkit';
import {
GenerationCommonConfigSchema,
ModelAction,
ModelInfo,
ModelReference,
modelRef,
} from 'genkit/model';
import { model as pluginModel } from 'genkit/plugin';
import { imagenPredict } from './client.js';
import { fromImagenResponse, toImagenPredictRequest } from './converters.js';
import { ClientOptions, Model, VertexPluginOptions } from './types.js';
import { checkModelName, extractVersion, modelName } from './utils.js';
/**
* See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api.
*/
export const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
// TODO: Remove common config schema extension since Imagen models don't support
// most of the common config parameters. Also, add more parameters like sampleCount
// from the above reference.
language: z
.enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN'])
.describe('Language of the prompt text.')
.optional(),
aspectRatio: z
.enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
.describe('Desired aspect ratio of the output image.')
.optional(),
negativePrompt: z
.string()
.describe(
'A description of what to discourage in the generated images. ' +
'For example: "animals" (removes animals), "blurry" ' +
'(makes the image clearer), "text" (removes text), or ' +
'"cropped" (removes cropped images).'
)
.optional(),
seed: z
.number()
.int()
.min(1)
.max(2147483647)
.describe(
'Controls the randomization of the image generation process. Use the ' +
'same seed across requests to provide consistency, or change it to ' +
'introduce variety in the response.'
)
.optional(),
location: z
.string()
.describe('Google Cloud region e.g. us-central1.')
.optional(),
personGeneration: z
.enum(['dont_allow', 'allow_adult', 'allow_all'])
.describe('Controls whether and how the model generates images of people.')
.optional(),
safetySetting: z
.enum(['block_most', 'block_some', 'block_few', 'block_fewest'])
.describe('Adds a filter level to safety filtering.')
.optional(),
addWatermark: z
.boolean()
.describe('Add an invisible watermark to the generated images.')
.optional(),
storageUri: z
.string()
.describe('Cloud Storage URI to store the generated images.')
.optional(),
mode: z
.enum(['upscale'])
.describe('Mode must be set for upscaling requests.')
.optional(),
/**
* Describes the editing intention for the request.
*
* See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details.
*/
editConfig: z
.object({
editMode: z
.enum([
'inpainting-insert',
'inpainting-remove',
'outpainting',
'product-image',
])
.describe('Editing intention for the request.')
.optional(),
maskMode: z
.object({
maskType: z
.enum(['background', 'foreground', 'semantic'])
.describe(
'"background" automatically generates a mask for all ' +
'regions except the primary subject(s) of the image, ' +
'"foreground" automatically generates a mask for the primary ' +
'subjects(s) of the image. "semantic" segments one or more ' +
'of the segmentation classes using class ID.'
),
classes: z
.array(z.number())
.describe('List of class IDs for segmentation.')
.length(5)
.optional(),
})
.describe(
'Prompts the model to generate a mask instead of requiring you ' +
'to provide one. Consequently, when you provide this parameter ' +
'you can omit a mask object.'
)
.passthrough()
.optional(),
maskDilation: z
.number()
.describe('Dilation percentage of the mask provided.')
.min(0.0)
.max(1.0)
.optional(),
guidanceScale: z
.number()
.describe(
'Controls how much the model adheres to the text prompt. ' +
'Higher values increase alignment between output and prompt, but may ' +
'compromise image quality. Suggested values are 0-9 ' +
'(low strength), 10-20 (medium strength), 21+ (high strength).'
)
.optional(),
productPosition: z
.enum(['reposition', 'fixed'])
.describe(
'Defines whether the product should stay fixed or be ' +
'repositioned.'
)
.optional(),
})
.passthrough()
.optional(),
upscaleConfig: z
.object({
upscaleFactor: z
.enum(['x2', 'x4'])
.describe('The factor by which to upscale the image.'),
})
.describe('Configuration for upscaling.')
.passthrough()
.optional(),
}).passthrough();
export type ImagenConfigSchemaType = typeof ImagenConfigSchema;
export type ImagenConfig = z.infer<ImagenConfigSchemaType>;
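// Example (informal sketch, not part of this module): config objects that
// satisfy ImagenConfigSchema. Passing the resolved reference to `ai.generate`
// is the usual Genkit pattern; the exact call site in an app may differ.
//
//   const generateConfig: ImagenConfig = {
//     aspectRatio: '16:9',
//     negativePrompt: 'blurry, text',
//     seed: 42,
//     safetySetting: 'block_few',
//   };
//
//   const upscaleRequestConfig: ImagenConfig = {
//     mode: 'upscale',
//     upscaleConfig: { upscaleFactor: 'x2' },
//   };
//
//   // e.g. await ai.generate({
//   //   model: model('imagen-3.0-generate-002', generateConfig),
//   //   prompt: 'a watercolor painting of a lighthouse',
//   // });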
// for commonRef
type ConfigSchemaType = ImagenConfigSchemaType;
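/**
 * Builds a `vertexai/`-prefixed model reference with Imagen defaults: media
 * output only, with no multiturn, tools, or system-role support unless `info`
 * overrides them. For example, commonRef('imagen-3.0-generate-002') yields a
 * reference named 'vertexai/imagen-3.0-generate-002'.
 */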
function commonRef(
name: string,
info?: ModelInfo,
configSchema: ConfigSchemaType = ImagenConfigSchema
): ModelReference<ConfigSchemaType> {
return modelRef({
name: `vertexai/${name}`,
configSchema,
info: info ?? {
supports: {
media: true,
multiturn: false,
tools: false,
toolChoice: false,
systemRole: false,
output: ['media'],
},
},
});
}
// Allow all the capabilities for unknown future models
const GENERIC_MODEL = commonRef('imagen', {
supports: {
media: true,
multiturn: true,
tools: true,
systemRole: true,
output: ['media'],
},
});
export const KNOWN_MODELS = {
'imagen-3.0-generate-002': commonRef('imagen-3.0-generate-002'),
'imagen-3.0-generate-001': commonRef('imagen-3.0-generate-001'),
'imagen-3.0-capability-001': commonRef('imagen-3.0-capability-001'),
'imagen-3.0-fast-generate-001': commonRef('imagen-3.0-fast-generate-001'),
'imagen-4.0-generate-preview-06-06': commonRef(
'imagen-4.0-generate-preview-06-06'
),
'imagen-4.0-ultra-generate-preview-06-06': commonRef(
'imagen-4.0-ultra-generate-preview-06-06'
),
} as const;
export type KnownModels = keyof typeof KNOWN_MODELS;
export type ImagenModelName = `imagen-${string}`;
export function isImagenModelName(value?: string): value is ImagenModelName {
return !!value?.startsWith('imagen-');
}
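// Examples (based on the `imagen-` prefix check above):
//   isImagenModelName('imagen-3.0-generate-002'); // true
//   isImagenModelName('gemini-2.0-flash'); // false
//   isImagenModelName(undefined); // false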
export function model(
version: string,
config: ImagenConfig = {}
): ModelReference<typeof ImagenConfigSchema> {
const name = checkModelName(version);
if (KNOWN_MODELS[name]) {
return KNOWN_MODELS[name].withConfig(config);
}
return modelRef({
name: `vertexai/${name}`,
config,
configSchema: ImagenConfigSchema,
info: {
...GENERIC_MODEL.info,
},
});
}
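// Example (sketch): resolve a reference with request config attached.
//   const ref = model('imagen-3.0-generate-002', { aspectRatio: '1:1', seed: 7 });
//   // ref.name === 'vertexai/imagen-3.0-generate-002'
// Unknown `imagen-*` versions fall back to GENERIC_MODEL's capability info.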
export function listActions(models: Model[]): ActionMetadata[] {
return models
.filter((m: Model) => isImagenModelName(modelName(m.name)))
.map((m: Model) => {
const ref = model(m.name);
return modelActionMetadata({
name: ref.name,
info: ref.info,
configSchema: ref.configSchema,
});
});
}
export function listKnownModels(
clientOptions: ClientOptions,
pluginOptions?: VertexPluginOptions
) {
return Object.keys(KNOWN_MODELS).map((name: string) =>
defineModel(name, clientOptions, pluginOptions)
);
}
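/**
 * Defines a runnable Imagen model action. At request time the Genkit request
 * is converted to an Imagen predict payload, sent via `imagenPredict` with the
 * shared client options plus the per-call abort signal, and the predictions
 * are converted back into a Genkit response. An empty prediction list is
 * surfaced as an error (commonly caused by content filtering).
 */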
export function defineModel(
name: string,
clientOptions: ClientOptions,
pluginOptions?: VertexPluginOptions
): ModelAction {
const ref = model(name);
return pluginModel(
{
name: ref.name,
...ref.info,
configSchema: ref.configSchema,
},
async (request, { abortSignal }) => {
const clientOpt = { ...clientOptions, signal: abortSignal };
const imagenPredictRequest = toImagenPredictRequest(request);
const response = await imagenPredict(
extractVersion(ref),
imagenPredictRequest,
clientOpt
);
if (!response.predictions || response.predictions.length === 0) {
throw new Error(
'Model returned no predictions. Possibly due to content filters.'
);
}
return fromImagenResponse(response, request);
}
);
}
export const TEST_ONLY = {
GENERIC_MODEL,
KNOWN_MODELS,
};