Support for GPT-4-Vision (khanon/oai-reverse-proxy!54)
parent 7f2f324e26
commit f29049f993
@@ -1,6 +1,7 @@
 import { RequestPreprocessor } from "./index";
-import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
+import { countTokens } from "../../../shared/tokenization";
 import { assertNever } from "../../../shared/utils";
+import type { OpenAIChatMessage } from "./transform-outbound-payload";
 
 /**
  * Given a request with an already-transformed body, counts the number of
@@ -13,7 +14,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
   switch (service) {
     case "openai": {
       req.outputTokens = req.body.max_tokens;
-      const prompt: OpenAIPromptMessage[] = req.body.messages;
+      const prompt: OpenAIChatMessage[] = req.body.messages;
       result = await countTokens({ req, prompt, service });
       break;
     }
@@ -3,6 +3,7 @@ import { config } from "../../../config";
 import { assertNever } from "../../../shared/utils";
 import { RequestPreprocessor } from ".";
 import { UserInputError } from "../../../shared/errors";
+import { OpenAIChatMessage } from "./transform-outbound-payload";
 
 const rejectedClients = new Map<string, number>();
 
@@ -53,9 +54,16 @@ function getPromptFromRequest(req: Request) {
       return body.prompt;
     case "openai":
       return body.messages
-        .map(
-          (m: { content: string; role: string }) => `${m.role}: ${m.content}`
-        )
+        .map((msg: OpenAIChatMessage) => {
+          const text = Array.isArray(msg.content)
+            ? msg.content
+                .map((c) => {
+                  if ("text" in c) return c.text;
+                })
+                .join()
+            : msg.content;
+          return `${msg.role}: ${text}`;
+        })
         .join("\n\n");
     case "openai-text":
     case "openai-image":
@@ -1,7 +1,6 @@
 import { Request } from "express";
 import { z } from "zod";
 import { config } from "../../../config";
-import { OpenAIPromptMessage } from "../../../shared/tokenization";
 import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
 import { RequestPreprocessor } from ".";
 import { APIFormat } from "../../../shared/key-management";
@@ -9,6 +8,8 @@ import { APIFormat } from "../../../shared/key-management";
 const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
 const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
 
+// TODO: move schemas to shared
+
 // https://console.anthropic.com/docs/api/reference#-v1-complete
 export const AnthropicV1CompleteSchema = z.object({
   model: z.string(),
@@ -29,12 +30,25 @@ export const AnthropicV1CompleteSchema = z.object({
 });
 
 // https://platform.openai.com/docs/api-reference/chat/create
-const OpenAIV1ChatCompletionSchema = z.object({
+const OpenAIV1ChatContentArraySchema = z.array(
+  z.union([
+    z.object({ type: z.literal("text"), text: z.string() }),
+    z.object({
+      type: z.literal("image_url"),
+      image_url: z.object({
+        url: z.string().url(),
+        detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
+      }),
+    }),
+  ])
+);
+
+export const OpenAIV1ChatCompletionSchema = z.object({
   model: z.string(),
   messages: z.array(
     z.object({
       role: z.enum(["system", "user", "assistant"]),
-      content: z.string(),
+      content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
       name: z.string().optional(),
     }),
     {
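For illustration only (not part of the diff): a message body the updated schema now accepts, pairing a text part with a base64 image part. The field names follow OpenAIV1ChatContentArraySchema above; the image data is a placeholder.

const visionMessage = {
  role: "user",
  content: [
    { type: "text", text: "What is in this image?" },
    {
      type: "image_url",
      image_url: {
        url: "data:image/png;base64,iVBORw0KGgo...", // placeholder, truncated
        detail: "auto",
      },
    },
  ],
};
// Plain string content remains valid because `content` is a union of
// z.string() and the new content array schema.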
@@ -68,6 +82,10 @@ const OpenAIV1ChatCompletionSchema = z.object({
   seed: z.number().int().optional(),
 });
 
+export type OpenAIChatMessage = z.infer<
+  typeof OpenAIV1ChatCompletionSchema
+>["messages"][0];
+
 const OpenAIV1TextCompletionSchema = z
   .object({
     model: z
@@ -232,7 +250,7 @@ function openaiToOpenaiText(req: Request) {
   }
 
   const { messages, ...rest } = result.data;
-  const prompt = flattenOpenAiChatMessages(messages);
+  const prompt = flattenOpenAIChatMessages(messages);
 
   let stops = rest.stop
     ? Array.isArray(rest.stop)
@@ -260,6 +278,9 @@ function openaiToOpenaiImage(req: Request) {
 
   const { messages } = result.data;
   const prompt = messages.filter((m) => m.role === "user").pop()?.content;
+  if (Array.isArray(prompt)) {
+    throw new Error("Image generation prompt must be a text message.");
+  }
 
   if (body.stream) {
     throw new Error(
@@ -304,7 +325,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
   }
 
   const { messages, ...rest } = result.data;
-  const prompt = flattenOpenAiChatMessages(messages);
+  const prompt = flattenOpenAIChatMessages(messages);
 
   let stops = rest.stop
     ? Array.isArray(rest.stop)
@@ -336,7 +357,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
   };
 }
 
-export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
+export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
   return (
     messages
       .map((m) => {
@@ -348,17 +369,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
         } else if (role === "user") {
           role = "Human";
         }
+        const name = m.name?.trim();
+        const content = flattenOpenAIMessageContent(m.content);
         // https://console.anthropic.com/docs/prompt-design
         // `name` isn't supported by Anthropic but we can still try to use it.
-        return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
-          m.content
-        }`;
+        return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
       })
       .join("") + "\n\nAssistant:"
   );
 }
 
-function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
+function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
   // Temporary to allow experimenting with prompt strategies
   const PROMPT_VERSION: number = 1;
   switch (PROMPT_VERSION) {
@@ -375,7 +396,7 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
           } else if (role === "user") {
             role = "User";
           }
-          return `\n\n${role}: ${m.content}`;
+          return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
         })
         .join("") + "\n\nAssistant:"
   );
@@ -387,10 +408,23 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
           if (role === "system") {
             role = "System: ";
           }
-          return `\n\n${role}${m.content}`;
+          return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
         })
         .join("");
     default:
       throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
   }
 }
+
+function flattenOpenAIMessageContent(
+  content: OpenAIChatMessage["content"]
+): string {
+  return Array.isArray(content)
+    ? content
+        .map((contentItem) => {
+          if ("text" in contentItem) return contentItem.text;
+          if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
+        })
+        .join("\n")
+    : content;
+}
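For illustration only (not part of the diff): given the helper above, a mixed text/image message flattens to plain text with the image replaced by a marker.

const flattened = flattenOpenAIMessageContent([
  { type: "text", text: "Describe this picture." },
  {
    type: "image_url",
    image_url: { url: "data:image/png;base64,...", detail: "auto" }, // placeholder
  },
]);
// => "Describe this picture.\n[ Uploaded Image Omitted ]"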
@@ -9,6 +9,7 @@ import {
 } from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
+import { OpenAIChatMessage } from "../request/transform-outbound-payload";
 
 /** If prompt logging is enabled, enqueues the prompt for logging. */
 export const logPrompt: ProxyResHandlerWithBody = async (
@@ -42,11 +43,6 @@ export const logPrompt: ProxyResHandlerWithBody = async (
   });
 };
 
-type OaiMessage = {
-  role: "user" | "assistant" | "system";
-  content: string;
-};
-
 type OaiImageResult = {
   prompt: string;
   size: string;
@@ -58,7 +54,7 @@ type OaiImageResult = {
 const getPromptForRequest = (
   req: Request,
   responseBody: Record<string, any>
-): string | OaiMessage[] | OaiImageResult => {
+): string | OpenAIChatMessage[] | OaiImageResult => {
   // Since the prompt logger only runs after the request has been proxied, we
   // can assume the body has already been transformed to the target API's
   // format.
@@ -85,13 +81,25 @@ const getPromptForRequest = (
 };
 
 const flattenMessages = (
-  val: string | OaiMessage[] | OaiImageResult
+  val: string | OpenAIChatMessage[] | OaiImageResult
 ): string => {
   if (typeof val === "string") {
     return val.trim();
   }
   if (Array.isArray(val)) {
-    return val.map((m) => `${m.role}: ${m.content}`).join("\n");
+    return val
+      .map(({ content, role }) => {
+        const text = Array.isArray(content)
+          ? content
+              .map((c) => {
+                if ("text" in c) return c.text;
+                if ("image_url" in c) return "(( Attached Image ))";
+              })
+              .join("\n")
+          : content;
+        return `${role}: ${text}`;
+      })
+      .join("\n");
   }
   return val.prompt.trim();
 };
@@ -26,6 +26,7 @@ import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/r
 // https://platform.openai.com/docs/models/overview
 const KNOWN_MODELS = [
   "gpt-4-1106-preview",
+  "gpt-4-vision-preview",
   "gpt-4",
   "gpt-4-0613",
   "gpt-4-0314", // EOL 2024-06-13
@@ -475,7 +475,7 @@ export function registerHeartbeat(req: Request) {
   const res = req.res!;
 
   const currentSize = getHeartbeatSize();
-  req.log.info({
+  req.log.debug({
     currentSize,
     HEARTBEAT_INTERVAL,
     PAYLOAD_SCALE_FACTOR,
@@ -17,8 +17,8 @@ proxyRouter.use((req, _res, next) => {
   next();
 });
 proxyRouter.use(
-  express.json({ limit: "1536kb" }),
-  express.urlencoded({ extended: true, limit: "1536kb" })
+  express.json({ limit: "10mb" }),
+  express.urlencoded({ extended: true, limit: "10mb" })
 );
 proxyRouter.use(gatekeeper);
 proxyRouter.use(checkRisuToken);
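Editorial aside (an assumption, not stated in the commit): the larger body limit presumably leaves headroom for base64-encoded images in chat payloads, since base64 inflates binary data by roughly 4/3.

// Rough sizing sketch, illustrative only.
const maxJsonBytes = 10 * 1024 * 1024; // new express.json limit
const maxRawImageBytes = Math.floor((maxJsonBytes * 3) / 4); // ~7.5 MiB of raw image data
// The previous 1536kb limit allowed only ~1.1 MiB of raw image data per request.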
@@ -0,0 +1,6 @@
+// We need to control the timing of when sharp is imported because it has a
+// native dependency that causes conflicts with node-canvas if they are not
+// imported in a specific order.
+import sharp from "sharp";
+
+export { sharp as libSharp };
@@ -3,11 +3,9 @@ import { promises as fs } from "fs";
 import path from "path";
 import { v4 } from "uuid";
 import { USER_ASSETS_DIR } from "../../config";
-import { logger } from "../../logger";
 import { addToImageHistory } from "./image-history";
-import sharp from "sharp";
+import { libSharp } from "./index";
 
-const log = logger.child({ module: "file-storage" });
 
 export type OpenAIImageGenerationResult = {
   created: number;
@@ -40,7 +38,7 @@ async function saveB64Image(b64: string) {
 async function createThumbnail(filepath: string) {
   const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");
 
-  await sharp(filepath)
+  await libSharp(filepath)
     .resize(150, 150, {
       fit: "inside",
       withoutEnlargement: true,
@@ -27,6 +27,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
 
 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
   "^gpt-4-1106(-preview)?$": "gpt4-turbo",
+  "^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
   "^gpt-4-32k-\\d{4}$": "gpt4-32k",
   "^gpt-4-32k$": "gpt4-32k",
   "^gpt-4-\\d{4}$": "gpt4",
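For illustration only (not part of the diff): the new pattern routes both current and future dated vision models to the gpt4-turbo family.

const visionPattern = new RegExp("^gpt-4(-\\d{4})?-vision(-preview)?$");
visionPattern.test("gpt-4-vision-preview"); // true
visionPattern.test("gpt-4-1106-vision-preview"); // true
visionPattern.test("gpt-4-1106-preview"); // false (handled by the gpt-4-1106 pattern above)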
@@ -1,2 +1 @@
-export { OpenAIPromptMessage } from "./openai";
 export { init, countTokens } from "./tokenizer";
@@ -1,5 +1,11 @@
 import { Tiktoken } from "tiktoken/lite";
 import cl100k_base from "tiktoken/encoders/cl100k_base.json";
+import { logger } from "../../logger";
+import { libSharp } from "../file-storage";
+import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
 
+const log = logger.child({ module: "tokenizer", service: "openai" });
+const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
+
 let encoder: Tiktoken;
 
@@ -15,8 +21,8 @@ export function init() {
 // Tested against:
 // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 
-export function getTokenCount(
-  prompt: string | OpenAIPromptMessage[],
+export async function getTokenCount(
+  prompt: string | OpenAIChatMessage[],
   model: string
 ) {
   if (typeof prompt === "string") {
@@ -24,31 +30,49 @@ export function getTokenCount(
   }
 
   const gpt4 = model.startsWith("gpt-4");
+  const vision = model.includes("vision");
+
   const tokensPerMessage = gpt4 ? 3 : 4;
   const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
 
-  let numTokens = 0;
+  let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
 
   for (const message of prompt) {
     numTokens += tokensPerMessage;
     for (const key of Object.keys(message)) {
       {
-        const value = message[key as keyof OpenAIPromptMessage];
-        if (!value || typeof value !== "string") continue;
+        let textContent: string = "";
+        const value = message[key as keyof OpenAIChatMessage];
+
+        if (!value) continue;
+
+        if (Array.isArray(value)) {
+          for (const item of value) {
+            if (item.type === "text") {
+              textContent += item.text;
+            } else if (item.type === "image_url") {
+              const { url, detail } = item.image_url;
+              const cost = await getGpt4VisionTokenCost(url, detail);
+              numTokens += cost ?? 0;
+            }
+          }
+        } else {
+          textContent = value;
+        }
 
         // Break if we get a huge message or exceed the token limit to prevent
         // DoS.
-        // 100k tokens allows for future 100k GPT-4 models and 500k characters
+        // 200k tokens allows for future 200k GPT-4 models and 500k characters
         // is just a sanity check
-        if (value.length > 500000 || numTokens > 100000) {
-          numTokens = 100000;
+        if (textContent.length > 500000 || numTokens > 200000) {
+          numTokens = 200000;
           return {
             tokenizer: "tiktoken (prompt length limit exceeded)",
            token_count: numTokens,
          };
        }
 
-        numTokens += encoder.encode(value).length;
+        numTokens += encoder.encode(textContent).length;
        if (key === "name") {
          numTokens += tokensPerName;
        }
@@ -59,6 +83,78 @@ export function getTokenCount(
   return { tokenizer: "tiktoken", token_count: numTokens };
 }
 
+async function getGpt4VisionTokenCost(
+  url: string,
+  detail: "auto" | "low" | "high" = "auto"
+) {
+  // For now we do not allow remote images as the proxy would have to download
+  // them, which is a potential DoS vector.
+  if (!url.startsWith("data:image/")) {
+    throw new Error(
+      "Remote images are not supported. Add the image to your prompt as a base64 data URL."
+    );
+  }
+
+  const base64Data = url.split(",")[1];
+  const buffer = Buffer.from(base64Data, "base64");
+  const image = libSharp(buffer);
+  const metadata = await image.metadata();
+
+  if (!metadata || !metadata.width || !metadata.height) {
+    throw new Error("Prompt includes an image that could not be parsed");
+  }
+
+  const { width, height } = metadata;
+
+  let selectedDetail: "low" | "high";
+  if (detail === "auto") {
+    const threshold = 512 * 512;
+    const imageSize = width * height;
+    selectedDetail = imageSize > threshold ? "high" : "low";
+  } else {
+    selectedDetail = detail;
+  }
+
+  // https://platform.openai.com/docs/guides/vision/calculating-costs
+  if (selectedDetail === "low") {
+    log.info(
+      { width, height, tokens: 85 },
+      "Using fixed GPT-4-Vision token cost for low detail image"
+    );
+    return 85;
+  }
+
+  let newWidth = width;
+  let newHeight = height;
+  if (width > 2048 || height > 2048) {
+    const aspectRatio = width / height;
+    if (width > height) {
+      newWidth = 2048;
+      newHeight = Math.round(2048 / aspectRatio);
+    } else {
+      newHeight = 2048;
+      newWidth = Math.round(2048 * aspectRatio);
+    }
+  }
+
+  if (newWidth < newHeight) {
+    newHeight = Math.round((newHeight / newWidth) * 768);
+    newWidth = 768;
+  } else {
+    newWidth = Math.round((newWidth / newHeight) * 768);
+    newHeight = 768;
+  }
+
+  const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
+  const tokens = 170 * tiles + 85;
+
+  log.info(
+    { width, height, newWidth, newHeight, tiles, tokens },
+    "Calculated GPT-4-Vision token cost for high detail image"
+  );
+  return tokens;
+}
+
 function getTextTokenCount(prompt: string) {
   if (prompt.length > 500000) {
     return {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export type OpenAIPromptMessage = {
|
|
||||||
name?: string;
|
|
||||||
content: string;
|
|
||||||
role: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Model Resolution Price
|
// Model Resolution Price
|
||||||
// DALL·E 3 1024×1024 $0.040 / image
|
// DALL·E 3 1024×1024 $0.040 / image
|
||||||
// 1024×1792, 1792×1024 $0.080 / image
|
// 1024×1792, 1792×1024 $0.080 / image
|
||||||
|
|
|
@@ -1,4 +1,5 @@
 import { Request } from "express";
+import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
 import { assertNever } from "../utils";
 import {
   init as initClaude,
@@ -7,7 +8,6 @@ import {
 import {
   init as initOpenAi,
   getTokenCount as getOpenAITokenCount,
-  OpenAIPromptMessage,
   getOpenAIImageCost,
 } from "./openai";
 import { APIFormat } from "../key-management";
@@ -20,7 +20,7 @@ export async function init() {
 /** Tagged union via `service` field of the different types of requests that can
  * be made to the tokenization service, for both prompts and completions */
 type TokenCountRequest = { req: Request } & (
-  | { prompt: OpenAIPromptMessage[]; completion?: never; service: "openai" }
+  | { prompt: OpenAIChatMessage[]; completion?: never; service: "openai" }
   | {
       prompt: string;
       completion?: never;
@@ -52,7 +52,7 @@ export async function countTokens({
     case "openai":
     case "openai-text":
       return {
-        ...getOpenAITokenCount(prompt ?? completion, req.body.model),
+        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
         tokenization_duration_ms: getElapsedMs(time),
       };
     case "openai-image":
@@ -69,7 +69,7 @@ export async function countTokens({
       // TODO: Can't find a tokenization library for PaLM. There is an API
       // endpoint for it but it adds significant latency to the request.
       return {
-        ...getOpenAITokenCount(prompt ?? completion, req.body.model),
+        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
         tokenization_duration_ms: getElapsedMs(time),
       };
     default: