// imagen.ts
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { GenkitError, MessageData, z, type Genkit } from 'genkit';
import {
getBasicUsageStats,
modelRef,
type GenerateRequest,
type ModelAction,
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { getApiKeyFromEnvVar } from './common.js';
import { predictModel } from './predict.js';
/** Imagen model name(s) known to this module; currently a single string literal. */
export type KNOWN_IMAGEN_MODELS = 'imagen-3.0-generate-002';
/**
 * Zod schema for the per-request configuration accepted by Imagen models.
 *
 * Every declared field is optional. `.passthrough()` keeps keys that are not
 * declared here instead of stripping them during parsing.
 *
 * See https://ai.google.dev/gemini-api/docs/image-generation#imagen-model
 */
export const ImagenConfigSchema = z
.object({
// Typed as a plain number: the 1-4 range in the description is not enforced by this schema.
numberOfImages: z
.number()
.describe(
'The number of images to generate, from 1 to 4 (inclusive). The default is 1.'
)
.optional(),
// Restricted to the five ratios listed in the enum.
aspectRatio: z
.enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
.describe('Desired aspect ratio of the output image.')
.optional(),
// Restricted to the three policy values listed in the enum.
personGeneration: z
.enum(['dont_allow', 'allow_adult', 'allow_all'])
.describe(
'Control if/how images of people will be generated by the model.'
)
.optional(),
})
.passthrough();
/**
 * Parameters object handed to the Imagen predict client together with the
 * instances. Every field is optional.
 */
interface ImagenParameters {
// Number of images requested; filled from the config's `numberOfImages`.
sampleCount?: number;
aspectRatio?: string;
personGeneration?: string;
}
/**
 * Builds the `parameters` payload for an Imagen predict call from the request
 * config.
 *
 * `numberOfImages` is mapped to `sampleCount` (1 when not set). All config keys
 * are then copied over unchanged, so `numberOfImages` itself is forwarded as
 * well and a passthrough key named `sampleCount` overrides the mapped value.
 *
 * Entries that carry no value (`undefined`, `null`, `''`, `NaN`) are omitted.
 * Explicit `false` and `0` are kept: they are deliberate settings, and dropping
 * them would silently swap the caller's choice for the service default.
 *
 * @param request Generate request whose `config` (possibly absent) is converted.
 * @returns A new parameters object; the request and its config are not mutated.
 */
function toParameters(
  request: GenerateRequest<typeof ImagenConfigSchema>
): ImagenParameters {
  const merged: Record<string, unknown> = {
    sampleCount: request.config?.numberOfImages ?? 1,
    // Spreading `undefined` is a no-op, so a missing config needs no guard.
    ...request.config,
  };
  const kept = Object.entries(merged).filter(
    ([, value]) => value === false || value === 0 || Boolean(value)
  );
  return Object.fromEntries(kept) as ImagenParameters;
}
/**
 * Concatenates the text of every part in the request's last message, in order.
 * Parts without text (e.g. media parts) contribute nothing.
 *
 * @param request Generate request to read the prompt from.
 * @returns The joined text, or `''` when the request has no messages.
 */
function extractText(request: GenerateRequest): string {
  const lastMessage = request.messages.at(-1);
  // Guard instead of a non-null assertion so an empty history yields an empty
  // prompt rather than an opaque TypeError.
  if (!lastMessage) {
    return '';
  }
  return lastMessage.content.map((part) => part.text ?? '').join('');
}
/**
 * Looks for an image attached to the request's last message.
 *
 * Takes the first part of that message that has `media` and returns the second
 * comma-separated segment of its URL; for a `data:<mime>;base64,<payload>` URL
 * that segment is the payload.
 *
 * @param request Generate request to inspect.
 * @returns The extracted segment, or `undefined` when there is no message, no
 *   media part, or no comma in the URL.
 */
function extractBaseImage(request: GenerateRequest): string | undefined {
  const lastMessage = request.messages.at(-1);
  const mediaPart = lastMessage?.content.find((part) => Boolean(part.media));
  const media = mediaPart?.media;
  if (!media) {
    return undefined;
  }
  const segments = media.url.split(',');
  return segments[1];
}
/**
 * Response shape expected from the Imagen predict call: a list of generated
 * images, each given as base64 data plus its MIME type.
 */
interface ImagenPrediction {
predictions: { bytesBase64Encoded: string; mimeType: string }[];
}
/**
 * One instance sent to the Imagen predict call: the text prompt and,
 * optionally, a base64-encoded input image.
 */
interface ImagenInstance {
prompt: string;
// Set only when the request's last message carries a media part.
image?: { bytesBase64Encoded: string };
// Declared here but never populated by the code in this file.
mask?: { image?: { bytesBase64Encoded: string } };
}
/**
 * Base model info shared by Imagen models. `defineImagenModel` spreads it into
 * each model's info and replaces `label` with a model-specific one.
 */
export const GENERIC_IMAGEN_INFO = {
label: `Google AI - Generic Imagen`,
supports: {
// Media input is flagged as supported; multi-turn, tools and system role are not.
media: true,
multiturn: false,
tools: false,
systemRole: false,
// 'media' is the only output kind listed.
output: ['media'],
},
} as ModelInfo;
/**
 * Defines an Imagen model on the given Genkit instance under the name
 * `googleai/<name>`.
 *
 * On each request the handler sends a single instance (the text of the last
 * message plus, if present, its first attached image) to the `predict` method
 * and returns every prediction as a `data:` URL media part of one model message.
 *
 * @param ai Genkit instance the model is defined on.
 * @param name Imagen model name; used in the action name, the label and the
 *   predict call.
 * @param apiKey API key to use. When omitted or empty, the key is read via
 *   `getApiKeyFromEnvVar()`. Passing `false` skips both that lookup and the
 *   missing-key check.
 * @returns The model action created by `ai.defineModel`.
 * @throws GenkitError (`FAILED_PRECONDITION`) when `apiKey` is not `false` and
 *   no key was passed or found. The request handler itself throws a plain
 *   `Error` when the response contains no predictions.
 */
export function defineImagenModel(
  ai: Genkit,
  name: string,
  apiKey?: string | false
): ModelAction {
  const resolvedApiKey =
    apiKey === false ? false : apiKey || getApiKeyFromEnvVar();
  if (resolvedApiKey !== false && !resolvedApiKey) {
    throw new GenkitError({
      status: 'FAILED_PRECONDITION',
      message:
        'Please pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.\n' +
        'For more details see https://genkit.dev/docs/plugins/google-genai',
    });
  }

  const modelName = `googleai/${name}`;
  const info = {
    ...GENERIC_IMAGEN_INFO,
    label: `Google AI - ${name}`,
  };
  const model: ModelReference<z.ZodTypeAny> = modelRef({
    name: modelName,
    info,
    configSchema: ImagenConfigSchema,
  });

  return ai.defineModel(
    {
      name: modelName,
      ...model.info,
      configSchema: ImagenConfigSchema,
    },
    async (request) => {
      const prompt = extractText(request);
      const baseImage = extractBaseImage(request);
      const instance: ImagenInstance = baseImage
        ? { prompt, image: { bytesBase64Encoded: baseImage } }
        : { prompt };

      // NOTE(review): when `apiKey` was `false`, that `false` is what gets cast
      // to string here — confirm `predictModel` copes with it.
      const callPredict = predictModel<
        ImagenInstance,
        ImagenPrediction,
        ImagenParameters
      >(model.version || name, resolvedApiKey as string, 'predict');
      const response = await callPredict([instance], toParameters(request));

      const predictions = response.predictions;
      if (!predictions || predictions.length === 0) {
        throw new Error(
          'Model returned no predictions. Possibly due to content filters.'
        );
      }

      // One media part per prediction, in response order.
      const message = {
        role: 'model',
        content: predictions.map(({ bytesBase64Encoded, mimeType }) => ({
          media: {
            url: `data:${mimeType};base64,${bytesBase64Encoded}`,
            contentType: mimeType,
          },
        })),
      } as MessageData;

      return {
        finishReason: 'stop',
        message,
        usage: getBasicUsageStats(request.messages, message),
        custom: response,
      };
    }
  );
}