Support for GPT-4-Vision (khanon/oai-reverse-proxy!54)
This commit is contained in:
parent
7f2f324e26
commit
f29049f993
|
@ -1,6 +1,7 @@
|
|||
import { RequestPreprocessor } from "./index";
|
||||
import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
|
||||
import { countTokens } from "../../../shared/tokenization";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import type { OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
|
||||
/**
|
||||
* Given a request with an already-transformed body, counts the number of
|
||||
|
@ -13,7 +14,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
|||
switch (service) {
|
||||
case "openai": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
const prompt: OpenAIPromptMessage[] = req.body.messages;
|
||||
const prompt: OpenAIChatMessage[] = req.body.messages;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ import { config } from "../../../config";
|
|||
import { assertNever } from "../../../shared/utils";
|
||||
import { RequestPreprocessor } from ".";
|
||||
import { UserInputError } from "../../../shared/errors";
|
||||
import { OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
|
||||
const rejectedClients = new Map<string, number>();
|
||||
|
||||
|
@ -53,9 +54,16 @@ function getPromptFromRequest(req: Request) {
|
|||
return body.prompt;
|
||||
case "openai":
|
||||
return body.messages
|
||||
.map(
|
||||
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
|
||||
)
|
||||
.map((msg: OpenAIChatMessage) => {
|
||||
const text = Array.isArray(msg.content)
|
||||
? msg.content
|
||||
.map((c) => {
|
||||
if ("text" in c) return c.text;
|
||||
})
|
||||
.join()
|
||||
: msg.content;
|
||||
return `${msg.role}: ${text}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
case "openai-text":
|
||||
case "openai-image":
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { OpenAIPromptMessage } from "../../../shared/tokenization";
|
||||
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
|
||||
import { RequestPreprocessor } from ".";
|
||||
import { APIFormat } from "../../../shared/key-management";
|
||||
|
@ -9,6 +8,8 @@ import { APIFormat } from "../../../shared/key-management";
|
|||
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
|
||||
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
|
||||
|
||||
// TODO: move schemas to shared
|
||||
|
||||
// https://console.anthropic.com/docs/api/reference#-v1-complete
|
||||
export const AnthropicV1CompleteSchema = z.object({
|
||||
model: z.string(),
|
||||
|
@ -29,12 +30,25 @@ export const AnthropicV1CompleteSchema = z.object({
|
|||
});
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
const OpenAIV1ChatCompletionSchema = z.object({
|
||||
const OpenAIV1ChatContentArraySchema = z.array(
|
||||
z.union([
|
||||
z.object({ type: z.literal("text"), text: z.string() }),
|
||||
z.object({
|
||||
type: z.literal("image_url"),
|
||||
image_url: z.object({
|
||||
url: z.string().url(),
|
||||
detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
|
||||
}),
|
||||
}),
|
||||
])
|
||||
);
|
||||
|
||||
export const OpenAIV1ChatCompletionSchema = z.object({
|
||||
model: z.string(),
|
||||
messages: z.array(
|
||||
z.object({
|
||||
role: z.enum(["system", "user", "assistant"]),
|
||||
content: z.string(),
|
||||
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
|
||||
name: z.string().optional(),
|
||||
}),
|
||||
{
|
||||
|
@ -68,6 +82,10 @@ const OpenAIV1ChatCompletionSchema = z.object({
|
|||
seed: z.number().int().optional(),
|
||||
});
|
||||
|
||||
export type OpenAIChatMessage = z.infer<
|
||||
typeof OpenAIV1ChatCompletionSchema
|
||||
>["messages"][0];
|
||||
|
||||
const OpenAIV1TextCompletionSchema = z
|
||||
.object({
|
||||
model: z
|
||||
|
@ -232,7 +250,7 @@ function openaiToOpenaiText(req: Request) {
|
|||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = flattenOpenAiChatMessages(messages);
|
||||
const prompt = flattenOpenAIChatMessages(messages);
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
|
@ -260,6 +278,9 @@ function openaiToOpenaiImage(req: Request) {
|
|||
|
||||
const { messages } = result.data;
|
||||
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
|
||||
if (Array.isArray(prompt)) {
|
||||
throw new Error("Image generation prompt must be a text message.");
|
||||
}
|
||||
|
||||
if (body.stream) {
|
||||
throw new Error(
|
||||
|
@ -304,7 +325,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
|||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = flattenOpenAiChatMessages(messages);
|
||||
const prompt = flattenOpenAIChatMessages(messages);
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
|
@ -336,7 +357,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
|||
};
|
||||
}
|
||||
|
||||
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
|
||||
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
|
||||
return (
|
||||
messages
|
||||
.map((m) => {
|
||||
|
@ -348,17 +369,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
|
|||
} else if (role === "user") {
|
||||
role = "Human";
|
||||
}
|
||||
const name = m.name?.trim();
|
||||
const content = flattenOpenAIMessageContent(m.content);
|
||||
// https://console.anthropic.com/docs/prompt-design
|
||||
// `name` isn't supported by Anthropic but we can still try to use it.
|
||||
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
|
||||
m.content
|
||||
}`;
|
||||
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
|
||||
})
|
||||
.join("") + "\n\nAssistant:"
|
||||
);
|
||||
}
|
||||
|
||||
function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
||||
function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
|
||||
// Temporary to allow experimenting with prompt strategies
|
||||
const PROMPT_VERSION: number = 1;
|
||||
switch (PROMPT_VERSION) {
|
||||
|
@ -375,7 +396,7 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
|||
} else if (role === "user") {
|
||||
role = "User";
|
||||
}
|
||||
return `\n\n${role}: ${m.content}`;
|
||||
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
|
||||
})
|
||||
.join("") + "\n\nAssistant:"
|
||||
);
|
||||
|
@ -387,10 +408,23 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
|||
if (role === "system") {
|
||||
role = "System: ";
|
||||
}
|
||||
return `\n\n${role}${m.content}`;
|
||||
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
|
||||
})
|
||||
.join("");
|
||||
default:
|
||||
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
|
||||
}
|
||||
}
|
||||
|
||||
function flattenOpenAIMessageContent(
|
||||
content: OpenAIChatMessage["content"]
|
||||
): string {
|
||||
return Array.isArray(content)
|
||||
? content
|
||||
.map((contentItem) => {
|
||||
if ("text" in contentItem) return contentItem.text;
|
||||
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
|
||||
})
|
||||
.join("\n")
|
||||
: content;
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import {
|
|||
} from "../common";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { OpenAIChatMessage } from "../request/transform-outbound-payload";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
|
@ -42,11 +43,6 @@ export const logPrompt: ProxyResHandlerWithBody = async (
|
|||
});
|
||||
};
|
||||
|
||||
type OaiMessage = {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string;
|
||||
};
|
||||
|
||||
type OaiImageResult = {
|
||||
prompt: string;
|
||||
size: string;
|
||||
|
@ -58,7 +54,7 @@ type OaiImageResult = {
|
|||
const getPromptForRequest = (
|
||||
req: Request,
|
||||
responseBody: Record<string, any>
|
||||
): string | OaiMessage[] | OaiImageResult => {
|
||||
): string | OpenAIChatMessage[] | OaiImageResult => {
|
||||
// Since the prompt logger only runs after the request has been proxied, we
|
||||
// can assume the body has already been transformed to the target API's
|
||||
// format.
|
||||
|
@ -85,13 +81,25 @@ const getPromptForRequest = (
|
|||
};
|
||||
|
||||
const flattenMessages = (
|
||||
val: string | OaiMessage[] | OaiImageResult
|
||||
val: string | OpenAIChatMessage[] | OaiImageResult
|
||||
): string => {
|
||||
if (typeof val === "string") {
|
||||
return val.trim();
|
||||
}
|
||||
if (Array.isArray(val)) {
|
||||
return val.map((m) => `${m.role}: ${m.content}`).join("\n");
|
||||
return val
|
||||
.map(({ content, role }) => {
|
||||
const text = Array.isArray(content)
|
||||
? content
|
||||
.map((c) => {
|
||||
if ("text" in c) return c.text;
|
||||
if ("image_url" in c) return "(( Attached Image ))";
|
||||
})
|
||||
.join("\n")
|
||||
: content;
|
||||
return `${role}: ${text}`;
|
||||
})
|
||||
.join("\n");
|
||||
}
|
||||
return val.prompt.trim();
|
||||
};
|
||||
|
|
|
@ -26,6 +26,7 @@ import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/r
|
|||
// https://platform.openai.com/docs/models/overview
|
||||
const KNOWN_MODELS = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-0314", // EOL 2024-06-13
|
||||
|
|
|
@ -475,7 +475,7 @@ export function registerHeartbeat(req: Request) {
|
|||
const res = req.res!;
|
||||
|
||||
const currentSize = getHeartbeatSize();
|
||||
req.log.info({
|
||||
req.log.debug({
|
||||
currentSize,
|
||||
HEARTBEAT_INTERVAL,
|
||||
PAYLOAD_SCALE_FACTOR,
|
||||
|
|
|
@ -17,8 +17,8 @@ proxyRouter.use((req, _res, next) => {
|
|||
next();
|
||||
});
|
||||
proxyRouter.use(
|
||||
express.json({ limit: "1536kb" }),
|
||||
express.urlencoded({ extended: true, limit: "1536kb" })
|
||||
express.json({ limit: "10mb" }),
|
||||
express.urlencoded({ extended: true, limit: "10mb" })
|
||||
);
|
||||
proxyRouter.use(gatekeeper);
|
||||
proxyRouter.use(checkRisuToken);
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
// We need to control the timing of when sharp is imported because it has a
|
||||
// native dependency that causes conflicts with node-canvas if they are not
|
||||
// imported in a specific order.
|
||||
import sharp from "sharp";
|
||||
|
||||
export { sharp as libSharp };
|
|
@ -3,11 +3,9 @@ import { promises as fs } from "fs";
|
|||
import path from "path";
|
||||
import { v4 } from "uuid";
|
||||
import { USER_ASSETS_DIR } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
import { addToImageHistory } from "./image-history";
|
||||
import sharp from "sharp";
|
||||
import { libSharp } from "./index";
|
||||
|
||||
const log = logger.child({ module: "file-storage" });
|
||||
|
||||
export type OpenAIImageGenerationResult = {
|
||||
created: number;
|
||||
|
@ -40,7 +38,7 @@ async function saveB64Image(b64: string) {
|
|||
async function createThumbnail(filepath: string) {
|
||||
const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");
|
||||
|
||||
await sharp(filepath)
|
||||
await libSharp(filepath)
|
||||
.resize(150, 150, {
|
||||
fit: "inside",
|
||||
withoutEnlargement: true,
|
||||
|
|
|
@ -27,6 +27,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^gpt-4-1106(-preview)?$": "gpt4-turbo",
|
||||
"^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
|
||||
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
|
||||
"^gpt-4-32k$": "gpt4-32k",
|
||||
"^gpt-4-\\d{4}$": "gpt4",
|
||||
|
|
|
@ -1,2 +1 @@
|
|||
export { OpenAIPromptMessage } from "./openai";
|
||||
export { init, countTokens } from "./tokenizer";
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
import { Tiktoken } from "tiktoken/lite";
|
||||
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
|
||||
import { logger } from "../../logger";
|
||||
import { libSharp } from "../file-storage";
|
||||
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
|
||||
|
||||
const log = logger.child({ module: "tokenizer", service: "openai" });
|
||||
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
|
||||
|
||||
let encoder: Tiktoken;
|
||||
|
||||
|
@ -15,8 +21,8 @@ export function init() {
|
|||
// Tested against:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
export function getTokenCount(
|
||||
prompt: string | OpenAIPromptMessage[],
|
||||
export async function getTokenCount(
|
||||
prompt: string | OpenAIChatMessage[],
|
||||
model: string
|
||||
) {
|
||||
if (typeof prompt === "string") {
|
||||
|
@ -24,31 +30,49 @@ export function getTokenCount(
|
|||
}
|
||||
|
||||
const gpt4 = model.startsWith("gpt-4");
|
||||
const vision = model.includes("vision");
|
||||
|
||||
const tokensPerMessage = gpt4 ? 3 : 4;
|
||||
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
|
||||
|
||||
let numTokens = 0;
|
||||
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
|
||||
|
||||
for (const message of prompt) {
|
||||
numTokens += tokensPerMessage;
|
||||
for (const key of Object.keys(message)) {
|
||||
{
|
||||
const value = message[key as keyof OpenAIPromptMessage];
|
||||
if (!value || typeof value !== "string") continue;
|
||||
let textContent: string = "";
|
||||
const value = message[key as keyof OpenAIChatMessage];
|
||||
|
||||
if (!value) continue;
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
for (const item of value) {
|
||||
if (item.type === "text") {
|
||||
textContent += item.text;
|
||||
} else if (item.type === "image_url") {
|
||||
const { url, detail } = item.image_url;
|
||||
const cost = await getGpt4VisionTokenCost(url, detail);
|
||||
numTokens += cost ?? 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
textContent = value;
|
||||
}
|
||||
|
||||
// Break if we get a huge message or exceed the token limit to prevent
|
||||
// DoS.
|
||||
// 100k tokens allows for future 100k GPT-4 models and 500k characters
|
||||
// 200k tokens allows for future 200k GPT-4 models and 500k characters
|
||||
// is just a sanity check
|
||||
if (value.length > 500000 || numTokens > 100000) {
|
||||
numTokens = 100000;
|
||||
if (textContent.length > 500000 || numTokens > 200000) {
|
||||
numTokens = 200000;
|
||||
return {
|
||||
tokenizer: "tiktoken (prompt length limit exceeded)",
|
||||
token_count: numTokens,
|
||||
};
|
||||
}
|
||||
|
||||
numTokens += encoder.encode(value).length;
|
||||
numTokens += encoder.encode(textContent).length;
|
||||
if (key === "name") {
|
||||
numTokens += tokensPerName;
|
||||
}
|
||||
|
@ -59,6 +83,78 @@ export function getTokenCount(
|
|||
return { tokenizer: "tiktoken", token_count: numTokens };
|
||||
}
|
||||
|
||||
async function getGpt4VisionTokenCost(
|
||||
url: string,
|
||||
detail: "auto" | "low" | "high" = "auto"
|
||||
) {
|
||||
// For now we do not allow remote images as the proxy would have to download
|
||||
// them, which is a potential DoS vector.
|
||||
if (!url.startsWith("data:image/")) {
|
||||
throw new Error(
|
||||
"Remote images are not supported. Add the image to your prompt as a base64 data URL."
|
||||
);
|
||||
}
|
||||
|
||||
const base64Data = url.split(",")[1];
|
||||
const buffer = Buffer.from(base64Data, "base64");
|
||||
const image = libSharp(buffer);
|
||||
const metadata = await image.metadata();
|
||||
|
||||
if (!metadata || !metadata.width || !metadata.height) {
|
||||
throw new Error("Prompt includes an image that could not be parsed");
|
||||
}
|
||||
|
||||
const { width, height } = metadata;
|
||||
|
||||
let selectedDetail: "low" | "high";
|
||||
if (detail === "auto") {
|
||||
const threshold = 512 * 512;
|
||||
const imageSize = width * height;
|
||||
selectedDetail = imageSize > threshold ? "high" : "low";
|
||||
} else {
|
||||
selectedDetail = detail;
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
||||
if (selectedDetail === "low") {
|
||||
log.info(
|
||||
{ width, height, tokens: 85 },
|
||||
"Using fixed GPT-4-Vision token cost for low detail image"
|
||||
);
|
||||
return 85;
|
||||
}
|
||||
|
||||
let newWidth = width;
|
||||
let newHeight = height;
|
||||
if (width > 2048 || height > 2048) {
|
||||
const aspectRatio = width / height;
|
||||
if (width > height) {
|
||||
newWidth = 2048;
|
||||
newHeight = Math.round(2048 / aspectRatio);
|
||||
} else {
|
||||
newHeight = 2048;
|
||||
newWidth = Math.round(2048 * aspectRatio);
|
||||
}
|
||||
}
|
||||
|
||||
if (newWidth < newHeight) {
|
||||
newHeight = Math.round((newHeight / newWidth) * 768);
|
||||
newWidth = 768;
|
||||
} else {
|
||||
newWidth = Math.round((newWidth / newHeight) * 768);
|
||||
newHeight = 768;
|
||||
}
|
||||
|
||||
const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
|
||||
const tokens = 170 * tiles + 85;
|
||||
|
||||
log.info(
|
||||
{ width, height, newWidth, newHeight, tiles, tokens },
|
||||
"Calculated GPT-4-Vision token cost for high detail image"
|
||||
);
|
||||
return tokens;
|
||||
}
|
||||
|
||||
function getTextTokenCount(prompt: string) {
|
||||
if (prompt.length > 500000) {
|
||||
return {
|
||||
|
@ -73,12 +169,6 @@ function getTextTokenCount(prompt: string) {
|
|||
};
|
||||
}
|
||||
|
||||
export type OpenAIPromptMessage = {
|
||||
name?: string;
|
||||
content: string;
|
||||
role: string;
|
||||
};
|
||||
|
||||
// Model Resolution Price
|
||||
// DALL·E 3 1024×1024 $0.040 / image
|
||||
// 1024×1792, 1792×1024 $0.080 / image
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import { Request } from "express";
|
||||
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
|
||||
import { assertNever } from "../utils";
|
||||
import {
|
||||
init as initClaude,
|
||||
|
@ -7,7 +8,6 @@ import {
|
|||
import {
|
||||
init as initOpenAi,
|
||||
getTokenCount as getOpenAITokenCount,
|
||||
OpenAIPromptMessage,
|
||||
getOpenAIImageCost,
|
||||
} from "./openai";
|
||||
import { APIFormat } from "../key-management";
|
||||
|
@ -20,7 +20,7 @@ export async function init() {
|
|||
/** Tagged union via `service` field of the different types of requests that can
|
||||
* be made to the tokenization service, for both prompts and completions */
|
||||
type TokenCountRequest = { req: Request } & (
|
||||
| { prompt: OpenAIPromptMessage[]; completion?: never; service: "openai" }
|
||||
| { prompt: OpenAIChatMessage[]; completion?: never; service: "openai" }
|
||||
| {
|
||||
prompt: string;
|
||||
completion?: never;
|
||||
|
@ -52,7 +52,7 @@ export async function countTokens({
|
|||
case "openai":
|
||||
case "openai-text":
|
||||
return {
|
||||
...getOpenAITokenCount(prompt ?? completion, req.body.model),
|
||||
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
};
|
||||
case "openai-image":
|
||||
|
@ -69,7 +69,7 @@ export async function countTokens({
|
|||
// TODO: Can't find a tokenization library for PaLM. There is an API
|
||||
// endpoint for it but it adds significant latency to the request.
|
||||
return {
|
||||
...getOpenAITokenCount(prompt ?? completion, req.body.model),
|
||||
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
};
|
||||
default:
|
||||
|
|
Loading…
Reference in New Issue