adds Claude 3 Vision support

This commit is contained in:
nai-degen 2024-03-05 18:34:10 -06:00
parent ea3aae5da6
commit ddf34685df
5 changed files with 102 additions and 41 deletions

View File

@ -10,7 +10,8 @@ import {
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import {
AnthropicChatMessage, flattenAnthropicMessages,
AnthropicChatMessage,
flattenAnthropicMessages,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../shared/api-schemas";
@ -95,11 +96,11 @@ const getPromptForRequest = (
const flattenMessages = (
val:
| string
| OpenAIChatMessage[]
| MistralAIChatMessage[]
| OaiImageResult
| AnthropicChatMessage[],
format: APIFormat,
| OpenAIChatMessage[]
| AnthropicChatMessage[]
| MistralAIChatMessage[],
format: APIFormat
): string => {
if (typeof val === "string") {
return val.trim();
@ -115,6 +116,8 @@ const flattenMessages = (
.map((c) => {
if ("text" in c) return c.text;
if ("image_url" in c) return "(( Attached Image ))";
if ("source" in c) return "(( Attached Image ))";
return "(( Unsupported Content ))";
})
.join("\n")
: content;

View File

@ -6,7 +6,6 @@ import {
OpenAIChatMessage,
OpenAIV1ChatCompletionSchema,
} from "./openai";
import { logger } from "../../logger";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
@ -33,23 +32,32 @@ export const AnthropicV1TextSchema = AnthropicV1BaseSchema.merge(
})
);
const AnthropicV1MessageMultimodalContentSchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.literal("image"),
source: z.object({
type: z.literal("base64"),
media_type: z.string().max(100),
data: z.string(),
}),
}),
])
);
// https://docs.anthropic.com/claude/reference/messages_post
export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
z.object({
messages: z
.array(
z.object({
role: z.enum(["user", "assistant"]),
content: z.union([
z.string(),
z.array(z.object({ type: z.string().max(100), text: z.string() })),
]),
})
)
.min(1)
.refine((v) => v[0].role === "user", {
message: `First message must be have 'user' role. Use 'system' parameter to start with a system message.`,
}),
messages: z.array(
z.object({
role: z.enum(["user", "assistant"]),
content: z.union([
z.string(),
AnthropicV1MessageMultimodalContentSchema,
]),
})
),
max_tokens: z
.number()
.int()
@ -219,8 +227,10 @@ export function flattenAnthropicMessages(
? msg.content
: [{ type: "text", text: msg.content }];
return `${name}: ${parts
.map(({ text, type }) =>
type === "text" ? text : `[Unsupported content type: ${type}]`
.map((part) =>
part.type === "text"
? part.text
: `[Omitted multimodal content of type ${part.type}]`
)
.join("\n")}`;
})

View File

@ -7,7 +7,7 @@ import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const KEY_CHECK_PERIOD = 30 * 60 * 1000; // 30 minutes
const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;

View File

@ -1,6 +1,10 @@
import { getTokenizer } from "@anthropic-ai/tokenizer";
import { Tiktoken } from "tiktoken/lite";
import { AnthropicChatMessage } from "../api-schemas";
import { libSharp } from "../file-storage";
import { logger } from "../../logger";
const log = logger.child({ module: "tokenizer", service: "anthropic" });
let encoder: Tiktoken;
let userRoleCount = 0;
@ -15,7 +19,7 @@ export function init() {
return true;
}
export function getTokenCount(prompt: string | AnthropicChatMessage[]) {
export async function getTokenCount(prompt: string | AnthropicChatMessage[]) {
if (typeof prompt !== "string") {
return getTokenCountForMessages(prompt);
}
@ -30,7 +34,7 @@ export function getTokenCount(prompt: string | AnthropicChatMessage[]) {
};
}
function getTokenCountForMessages(messages: AnthropicChatMessage[]) {
async function getTokenCountForMessages(messages: AnthropicChatMessage[]) {
let numTokens = 0;
for (const message of messages) {
@ -39,20 +43,23 @@ function getTokenCountForMessages(messages: AnthropicChatMessage[]) {
const parts = Array.isArray(content)
? content
: [{ type: "text", text: content }];
: [{ type: "text" as const, text: content }];
for (const part of parts) {
// We don't allow other content types for now because we can't estimate
// cost for them.
if (part.type !== "text") {
throw new Error(`Unsupported Anthropic content type: ${part.type}`);
switch (part.type) {
case "text":
const { text } = part;
if (text.length > 800000 || numTokens > 200000) {
throw new Error("Text content is too large to tokenize.");
}
numTokens += encoder.encode(text.normalize("NFKC"), "all").length;
break;
case "image":
numTokens += await getImageTokenCount(part.source.data);
break;
default:
throw new Error(`Unsupported Anthropic content type.`);
}
if (part.text.length > 800000 || numTokens > 200000) {
throw new Error("Content is too large to tokenize.");
}
numTokens += encoder.encode(part.text.normalize("NFKC"), "all").length;
}
}
@ -62,3 +69,48 @@ function getTokenCountForMessages(messages: AnthropicChatMessage[]) {
return { tokenizer: "@anthropic-ai/tokenizer", token_count: numTokens };
}
async function getImageTokenCount(b64: string) {
// https://docs.anthropic.com/claude/docs/vision
// If your image's long edge is more than 1568 pixels, or your image is more
// than ~1600 tokens, it will first be scaled down, preserving aspect ratio,
// until it is within size limits. Assuming your image does not need to be
// resized, you can estimate the number of tokens used via this simple
// algorithm:
// tokens = (width px * height px)/750
const buffer = Buffer.from(b64, "base64");
const image = libSharp(buffer);
const metadata = await image.metadata();
if (!metadata || !metadata.width || !metadata.height) {
throw new Error("Prompt includes an image that could not be parsed");
}
const MAX_TOKENS = 1600;
const MAX_LENGTH_PX = 1568;
const PIXELS_PER_TOKEN = 750;
const { width, height } = metadata;
let tokens = (width * height) / PIXELS_PER_TOKEN;
// Resize the image if it's too large
if (tokens > MAX_TOKENS || width > MAX_LENGTH_PX || height > MAX_LENGTH_PX) {
const longestEdge = Math.max(width, height);
let factor;
if (tokens > MAX_TOKENS) {
const targetPixels = PIXELS_PER_TOKEN * MAX_TOKENS;
factor = Math.sqrt(targetPixels / (width * height));
} else {
factor = MAX_LENGTH_PX / longestEdge;
}
const scaledWidth = width * factor;
const scaledHeight = height * factor;
tokens = (scaledWidth * scaledHeight) / 750;
}
log.debug({ width, height, tokens }, "Calculated Claude Vision token cost");
return Math.ceil(tokens);
}

View File

@ -99,13 +99,9 @@ export async function countTokens({
const time = process.hrtime();
switch (service) {
case "anthropic-chat":
return {
...getClaudeTokenCount(prompt ?? completion),
tokenization_duration_ms: getElapsedMs(time),
};
case "anthropic-text":
return {
...getClaudeTokenCount(prompt ?? completion),
...(await getClaudeTokenCount(prompt ?? completion)),
tokenization_duration_ms: getElapsedMs(time),
};
case "openai":