Support for GPT-4-Vision (khanon/oai-reverse-proxy!54)

This commit is contained in:
khanon 2023-11-19 05:06:21 +00:00
parent 7f2f324e26
commit f29049f993
13 changed files with 198 additions and 52 deletions

View File

@ -1,6 +1,7 @@
import { RequestPreprocessor } from "./index";
import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
import { countTokens } from "../../../shared/tokenization";
import { assertNever } from "../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload";
/**
* Given a request with an already-transformed body, counts the number of
@ -13,7 +14,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
switch (service) {
case "openai": {
req.outputTokens = req.body.max_tokens;
const prompt: OpenAIPromptMessage[] = req.body.messages;
const prompt: OpenAIChatMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service });
break;
}

View File

@ -3,6 +3,7 @@ import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";
import { UserInputError } from "../../../shared/errors";
import { OpenAIChatMessage } from "./transform-outbound-payload";
const rejectedClients = new Map<string, number>();
@ -53,9 +54,16 @@ function getPromptFromRequest(req: Request) {
return body.prompt;
case "openai":
return body.messages
.map(
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
)
.map((msg: OpenAIChatMessage) => {
const text = Array.isArray(msg.content)
? msg.content
.map((c) => {
if ("text" in c) return c.text;
})
.join()
: msg.content;
return `${msg.role}: ${text}`;
})
.join("\n\n");
case "openai-text":
case "openai-image":

View File

@ -1,7 +1,6 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../shared/tokenization";
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";
@ -9,6 +8,8 @@ import { APIFormat } from "../../../shared/key-management";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// TODO: move schemas to shared
// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z.object({
model: z.string(),
@ -29,12 +30,25 @@ export const AnthropicV1CompleteSchema = z.object({
});
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatCompletionSchema = z.object({
const OpenAIV1ChatContentArraySchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.literal("image_url"),
image_url: z.object({
url: z.string().url(),
detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
}),
}),
])
);
export const OpenAIV1ChatCompletionSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
name: z.string().optional(),
}),
{
@ -68,6 +82,10 @@ const OpenAIV1ChatCompletionSchema = z.object({
seed: z.number().int().optional(),
});
export type OpenAIChatMessage = z.infer<
typeof OpenAIV1ChatCompletionSchema
>["messages"][0];
const OpenAIV1TextCompletionSchema = z
.object({
model: z
@ -232,7 +250,7 @@ function openaiToOpenaiText(req: Request) {
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAiChatMessages(messages);
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
@ -260,6 +278,9 @@ function openaiToOpenaiImage(req: Request) {
const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (Array.isArray(prompt)) {
throw new Error("Image generation prompt must be a text message.");
}
if (body.stream) {
throw new Error(
@ -304,7 +325,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAiChatMessages(messages);
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
@ -336,7 +357,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
};
}
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
@ -348,17 +369,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
m.content
}`;
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
// Temporary to allow experimenting with prompt strategies
const PROMPT_VERSION: number = 1;
switch (PROMPT_VERSION) {
@ -375,7 +396,7 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${m.content}`;
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
@ -387,10 +408,23 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
if (role === "system") {
role = "System: ";
}
return `\n\n${role}${m.content}`;
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
})
.join("");
default:
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
}
}
function flattenOpenAIMessageContent(
content: OpenAIChatMessage["content"]
): string {
return Array.isArray(content)
? content
.map((contentItem) => {
if ("text" in contentItem) return contentItem.text;
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
})
.join("\n")
: content;
}

View File

@ -9,6 +9,7 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/transform-outbound-payload";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@ -42,11 +43,6 @@ export const logPrompt: ProxyResHandlerWithBody = async (
});
};
type OaiMessage = {
role: "user" | "assistant" | "system";
content: string;
};
type OaiImageResult = {
prompt: string;
size: string;
@ -58,7 +54,7 @@ type OaiImageResult = {
const getPromptForRequest = (
req: Request,
responseBody: Record<string, any>
): string | OaiMessage[] | OaiImageResult => {
): string | OpenAIChatMessage[] | OaiImageResult => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
@ -85,13 +81,25 @@ const getPromptForRequest = (
};
const flattenMessages = (
val: string | OaiMessage[] | OaiImageResult
val: string | OpenAIChatMessage[] | OaiImageResult
): string => {
if (typeof val === "string") {
return val.trim();
}
if (Array.isArray(val)) {
return val.map((m) => `${m.role}: ${m.content}`).join("\n");
return val
.map(({ content, role }) => {
const text = Array.isArray(content)
? content
.map((c) => {
if ("text" in c) return c.text;
if ("image_url" in c) return "(( Attached Image ))";
})
.join("\n")
: content;
return `${role}: ${text}`;
})
.join("\n");
}
return val.prompt.trim();
};

View File

@ -26,6 +26,7 @@ import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/r
// https://platform.openai.com/docs/models/overview
const KNOWN_MODELS = [
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2024-06-13

View File

@ -475,7 +475,7 @@ export function registerHeartbeat(req: Request) {
const res = req.res!;
const currentSize = getHeartbeatSize();
req.log.info({
req.log.debug({
currentSize,
HEARTBEAT_INTERVAL,
PAYLOAD_SCALE_FACTOR,

View File

@ -17,8 +17,8 @@ proxyRouter.use((req, _res, next) => {
next();
});
proxyRouter.use(
express.json({ limit: "1536kb" }),
express.urlencoded({ extended: true, limit: "1536kb" })
express.json({ limit: "10mb" }),
express.urlencoded({ extended: true, limit: "10mb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);

View File

@ -0,0 +1,6 @@
// We need to control the timing of when sharp is imported because it has a
// native dependency that causes conflicts with node-canvas if they are not
// imported in a specific order.
import sharp from "sharp";
export { sharp as libSharp };

View File

@ -3,11 +3,9 @@ import { promises as fs } from "fs";
import path from "path";
import { v4 } from "uuid";
import { USER_ASSETS_DIR } from "../../config";
import { logger } from "../../logger";
import { addToImageHistory } from "./image-history";
import sharp from "sharp";
import { libSharp } from "./index";
const log = logger.child({ module: "file-storage" });
export type OpenAIImageGenerationResult = {
created: number;
@ -40,7 +38,7 @@ async function saveB64Image(b64: string) {
async function createThumbnail(filepath: string) {
const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");
await sharp(filepath)
await libSharp(filepath)
.resize(150, 150, {
fit: "inside",
withoutEnlargement: true,

View File

@ -27,6 +27,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4-1106(-preview)?$": "gpt4-turbo",
"^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
"^gpt-4-32k$": "gpt4-32k",
"^gpt-4-\\d{4}$": "gpt4",

View File

@ -1,2 +1 @@
export { OpenAIPromptMessage } from "./openai";
export { init, countTokens } from "./tokenizer";

View File

@ -1,5 +1,11 @@
import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
let encoder: Tiktoken;
@ -15,8 +21,8 @@ export function init() {
// Tested against:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
export function getTokenCount(
prompt: string | OpenAIPromptMessage[],
export async function getTokenCount(
prompt: string | OpenAIChatMessage[],
model: string
) {
if (typeof prompt === "string") {
@ -24,31 +30,49 @@ export function getTokenCount(
}
const gpt4 = model.startsWith("gpt-4");
const vision = model.includes("vision");
const tokensPerMessage = gpt4 ? 3 : 4;
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
let numTokens = 0;
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
for (const key of Object.keys(message)) {
{
const value = message[key as keyof OpenAIPromptMessage];
if (!value || typeof value !== "string") continue;
let textContent: string = "";
const value = message[key as keyof OpenAIChatMessage];
if (!value) continue;
if (Array.isArray(value)) {
for (const item of value) {
if (item.type === "text") {
textContent += item.text;
} else if (item.type === "image_url") {
const { url, detail } = item.image_url;
const cost = await getGpt4VisionTokenCost(url, detail);
numTokens += cost ?? 0;
}
}
} else {
textContent = value;
}
// Break if we get a huge message or exceed the token limit to prevent
// DoS.
// 100k tokens allows for future 100k GPT-4 models and 500k characters
// 200k tokens allows for future 200k GPT-4 models and 500k characters
// is just a sanity check
if (value.length > 500000 || numTokens > 100000) {
numTokens = 100000;
if (textContent.length > 500000 || numTokens > 200000) {
numTokens = 200000;
return {
tokenizer: "tiktoken (prompt length limit exceeded)",
token_count: numTokens,
};
}
numTokens += encoder.encode(value).length;
numTokens += encoder.encode(textContent).length;
if (key === "name") {
numTokens += tokensPerName;
}
@ -59,6 +83,78 @@ export function getTokenCount(
return { tokenizer: "tiktoken", token_count: numTokens };
}
async function getGpt4VisionTokenCost(
url: string,
detail: "auto" | "low" | "high" = "auto"
) {
// For now we do not allow remote images as the proxy would have to download
// them, which is a potential DoS vector.
if (!url.startsWith("data:image/")) {
throw new Error(
"Remote images are not supported. Add the image to your prompt as a base64 data URL."
);
}
const base64Data = url.split(",")[1];
const buffer = Buffer.from(base64Data, "base64");
const image = libSharp(buffer);
const metadata = await image.metadata();
if (!metadata || !metadata.width || !metadata.height) {
throw new Error("Prompt includes an image that could not be parsed");
}
const { width, height } = metadata;
let selectedDetail: "low" | "high";
if (detail === "auto") {
const threshold = 512 * 512;
const imageSize = width * height;
selectedDetail = imageSize > threshold ? "high" : "low";
} else {
selectedDetail = detail;
}
// https://platform.openai.com/docs/guides/vision/calculating-costs
if (selectedDetail === "low") {
log.info(
{ width, height, tokens: 85 },
"Using fixed GPT-4-Vision token cost for low detail image"
);
return 85;
}
let newWidth = width;
let newHeight = height;
if (width > 2048 || height > 2048) {
const aspectRatio = width / height;
if (width > height) {
newWidth = 2048;
newHeight = Math.round(2048 / aspectRatio);
} else {
newHeight = 2048;
newWidth = Math.round(2048 * aspectRatio);
}
}
if (newWidth < newHeight) {
newHeight = Math.round((newHeight / newWidth) * 768);
newWidth = 768;
} else {
newWidth = Math.round((newWidth / newHeight) * 768);
newHeight = 768;
}
const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
const tokens = 170 * tiles + 85;
log.info(
{ width, height, newWidth, newHeight, tiles, tokens },
"Calculated GPT-4-Vision token cost for high detail image"
);
return tokens;
}
function getTextTokenCount(prompt: string) {
if (prompt.length > 500000) {
return {
@ -73,12 +169,6 @@ function getTextTokenCount(prompt: string) {
};
}
export type OpenAIPromptMessage = {
name?: string;
content: string;
role: string;
};
// Model Resolution Price
// DALL·E 3 1024×1024 $0.040 / image
// 1024×1792, 1792×1024 $0.080 / image

View File

@ -1,4 +1,5 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
import { assertNever } from "../utils";
import {
init as initClaude,
@ -7,7 +8,6 @@ import {
import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
OpenAIPromptMessage,
getOpenAIImageCost,
} from "./openai";
import { APIFormat } from "../key-management";
@ -20,7 +20,7 @@ export async function init() {
/** Tagged union via `service` field of the different types of requests that can
* be made to the tokenization service, for both prompts and completions */
type TokenCountRequest = { req: Request } & (
| { prompt: OpenAIPromptMessage[]; completion?: never; service: "openai" }
| { prompt: OpenAIChatMessage[]; completion?: never; service: "openai" }
| {
prompt: string;
completion?: never;
@ -52,7 +52,7 @@ export async function countTokens({
case "openai":
case "openai-text":
return {
...getOpenAITokenCount(prompt ?? completion, req.body.model),
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
tokenization_duration_ms: getElapsedMs(time),
};
case "openai-image":
@ -69,7 +69,7 @@ export async function countTokens({
// TODO: Can't find a tokenization library for PaLM. There is an API
// endpoint for it but it adds significant latency to the request.
return {
...getOpenAITokenCount(prompt ?? completion, req.body.model),
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
tokenization_duration_ms: getElapsedMs(time),
};
default: