nai-degen 2024-03-16 00:04:27 -05:00
parent d9117bf08e
commit 84acc429d7
38 changed files with 635 additions and 596 deletions

View File

@ -6,7 +6,7 @@ import {
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../../shared/api-schemas";
} from "../../../../shared/api-support";
/**
* Given a request with an already-transformed body, counts the number of

View File

@ -7,7 +7,7 @@ import {
MistralAIChatMessage,
OpenAIChatMessage,
flattenAnthropicMessages,
} from "../../../../shared/api-schemas";
} from "../../../../shared/api-support";
const rejectedClients = new Map<string, number>();

View File

@ -5,7 +5,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas";
} from "../../../../shared/api-support";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";

View File

@ -1,14 +1,14 @@
import {
API_REQUEST_VALIDATORS,
API_REQUEST_TRANSFORMERS,
} from "../../../../shared/api-schemas";
} from "../../../../shared/api-support";
import { BadRequestError } from "../../../../shared/errors";
import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai";
import {
isImageGenerationRequest,
isTextGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index";
import { fixMistralPrompt } from "../../../../shared/api-support/kits/mistral-ai/request-transformers";
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: RequestPreprocessor = async (req) => {

View File

@ -14,7 +14,7 @@ import {
flattenAnthropicMessages,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../shared/api-schemas";
} from "../../../shared/api-support";
import { APIFormat } from "../../../shared/key-management";
/** If prompt logging is enabled, enqueues the prompt for logging. */

View File

@ -1,62 +0,0 @@
import type { Request } from "express";
import { z } from "zod";
import { APIFormat } from "../key-management";
import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
transformAnthropicTextToAnthropicChat,
transformOpenAIToAnthropicText,
transformOpenAIToAnthropicChat,
} from "./anthropic";
import { OpenAIV1ChatCompletionSchema } from "./openai";
import {
OpenAIV1TextCompletionSchema,
transformOpenAIToOpenAIText,
} from "./openai-text";
import {
OpenAIV1ImagesGenerationSchema,
transformOpenAIToOpenAIImage,
} from "./openai-image";
import {
GoogleAIV1GenerateContentSchema,
transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
export { OpenAIChatMessage } from "./openai";
export {
AnthropicChatMessage,
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
flattenAnthropicMessages,
} from "./anthropic";
export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai";
type APIPair = `${APIFormat}->${APIFormat}`;
type TransformerMap = {
[key in APIPair]?: APIFormatTransformer<any>;
};
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
req: Request
) => Promise<z.infer<Z>>;
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
"openai->anthropic-text": transformOpenAIToAnthropicText,
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
};
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
};

View File

@ -1,58 +0,0 @@
import { z } from "zod";
import {
flattenOpenAIChatMessages,
OpenAIV1ChatCompletionSchema,
} from "./openai";
import { APIFormatTransformer } from "./index";
export const OpenAIV1TextCompletionSchema = z
.object({
model: z
.string()
.max(100)
.regex(
/^gpt-3.5-turbo-instruct/,
"Model must start with 'gpt-3.5-turbo-instruct'"
),
prompt: z.string({
required_error:
"No `prompt` found. Ensure you've set the correct completion endpoint.",
}),
logprobs: z.number().int().nullish().default(null),
echo: z.boolean().optional().default(false),
best_of: z.literal(1).optional(),
stop: z
.union([z.string().max(500), z.array(z.string().max(500)).max(4)])
.optional(),
suffix: z.string().max(1000).optional(),
})
.strip()
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true, logprobs: true }));
export const transformOpenAIToOpenAIText: APIFormatTransformer<
typeof OpenAIV1TextCompletionSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-text request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push("\n\nUser:");
stops = [...new Set(stops)];
const transformed = { ...rest, prompt: prompt, stop: stops };
return OpenAIV1TextCompletionSchema.parse(transformed);
};

View File

@ -0,0 +1,84 @@
import type { Request, Response } from "express";
import { z } from "zod";
import { APIFormat } from "../key-management";
import { AnthropicV1MessagesSchema } from "./kits/anthropic-chat/schema";
import { AnthropicV1TextSchema } from "./kits/anthropic-text/schema";
import { transformOpenAIToAnthropicText } from "./kits/anthropic-text/request-transformers";
import {
transformAnthropicTextToAnthropicChat,
transformOpenAIToAnthropicChat,
} from "./kits/anthropic-chat/request-transformers";
import { GoogleAIV1GenerateContentSchema } from "./kits/google-ai/schema";
import { transformOpenAIToGoogleAI } from "./kits/google-ai/request-transformers";
import { MistralAIV1ChatCompletionsSchema } from "./kits/mistral-ai/schema";
import { OpenAIV1ChatCompletionSchema } from "./kits/openai/schema";
import { OpenAIV1ImagesGenerationSchema } from "./kits/openai-image/schema";
import { transformOpenAIToOpenAIImage } from "./kits/openai-image/request-transformers";
import { OpenAIV1TextCompletionSchema } from "./kits/openai-text/schema";
import { transformOpenAIToOpenAIText } from "./kits/openai-text/request-transformers";
export type APIRequestTransformer<Z extends z.ZodType<any, any>> = (
req: Request
) => Promise<z.infer<Z>>;
export type APIResponseTransformer<Z extends z.ZodType<any, any>> = (
res: Response
) => Promise<z.infer<Z>>;
/** Represents a transformation from one API format to another. */
type APITransformation = `${APIFormat}->${APIFormat}`;
type APIRequestTransformerMap = {
[key in APITransformation]?: APIRequestTransformer<any>;
};
type APIResponseTransformerMap = {
[key in APITransformation]?: APIResponseTransformer<any>;
};
export const API_REQUEST_TRANSFORMERS: APIRequestTransformerMap = {
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
"openai->anthropic-text": transformOpenAIToAnthropicText,
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
};
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
};
export { AnthropicChatMessage } from "./kits/anthropic-chat/schema";
export { AnthropicV1MessagesSchema } from "./kits/anthropic-chat/schema";
export { AnthropicV1TextSchema } from "./kits/anthropic-text/schema";
export interface APIFormatKit<T extends APIFormat, P> {
name: T;
/** Zod schema for validating requests in this format. */
requestValidator: z.ZodSchema<any>;
/** Flattens non-string prompts (such as message arrays) into a single string. */
promptStringifier: (prompt: P) => string;
/** Counts the number of tokens in a prompt. */
promptTokenCounter: (prompt: P, model: string) => Promise<number>;
/** Counts the number of tokens in a completion. */
completionTokenCounter: (
completion: string,
model: string
) => Promise<number>;
/** Functions which transform requests from other formats into this format. */
requestTransformers: APIRequestTransformerMap;
/** Functions which transform responses from this format into other formats. */
responseTransformers: APIResponseTransformerMap;
}
export { GoogleAIChatMessage } from "./kits/google-ai";
export { MistralAIChatMessage } from "./kits/mistral-ai";
export { OpenAIChatMessage } from "./kits/openai/schema";
export { flattenAnthropicMessages } from "./kits/anthropic-chat/stringifier";
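A minimal sketch of how a request preprocessor might consume these two maps, assuming hypothetical `inboundApi`/`outboundApi` variables and a hypothetical `prepareBody` helper (neither is part of this commit):
import type { Request } from "express";
import { APIFormat } from "../key-management";
import {
  API_REQUEST_TRANSFORMERS,
  API_REQUEST_VALIDATORS,
} from "./index";
// Validate in place when the formats match; otherwise look up a transformer
// keyed by the "<from>-><to>" pair.
async function prepareBody(
  req: Request,
  inboundApi: APIFormat,
  outboundApi: APIFormat
) {
  if (inboundApi === outboundApi) {
    return API_REQUEST_VALIDATORS[inboundApi].parse(req.body);
  }
  const key =
    `${inboundApi}->${outboundApi}` as keyof typeof API_REQUEST_TRANSFORMERS;
  const transformer = API_REQUEST_TRANSFORMERS[key];
  if (!transformer) {
    throw new Error(`No transformer from ${inboundApi} to ${outboundApi}`);
  }
  return transformer(req);
}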

View File

@ -0,0 +1,4 @@
# API Kits
This directory contains "kits" for each supported language model API. Each kit implements the `APIFormatKit` interface and provides the functionality the proxy application needs to validate requests, transform prompts and responses, tokenize text, and so forth.
## Structure
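Based on the files added in this commit, a kit's directory looks roughly like this (not every kit implements every module yet):
kits/<api-format>/
  schema.ts                # Zod request schema and message types
  stringifier.ts           # flattens message arrays into a plain-text prompt
  tokenizer.ts             # token counting for prompts and completions
  request-transformers.ts  # transformers from other formats into this one
  index.ts                 # assembles the kit's APIFormatKit implementation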

View File

@ -1,98 +1,23 @@
import { z } from "zod";
import { config } from "../../config";
import { BadRequestError } from "../errors";
import {
flattenOpenAIMessageContent,
OpenAIChatMessage,
OpenAIV1ChatCompletionSchema,
} from "./openai";
import { APIFormatTransformer } from "./index";
import { AnthropicChatMessage, AnthropicV1MessagesSchema } from "./schema";
import {
AnthropicV1TextSchema,
APIRequestTransformer,
OpenAIChatMessage,
} from "../../index";
import { BadRequestError } from "../../../errors";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
const AnthropicV1BaseSchema = z
.object({
model: z.string().max(100),
stop_sequences: z.array(z.string().max(500)).optional(),
stream: z.boolean().optional().default(false),
temperature: z.coerce.number().optional().default(1),
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.object({ user_id: z.string().optional() }).optional(),
})
.strip();
/**
* Represents the union of all content types without the `string` shorthand
* for `text` content.
*/
type AnthropicChatMessageContentWithoutString = Exclude<
AnthropicChatMessage["content"],
string
>;
/** Represents a message with all shorthand `string` content expanded. */
type ConvertedAnthropicChatMessage = AnthropicChatMessage & {
content: AnthropicChatMessageContentWithoutString;
};
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
export const AnthropicV1TextSchema = AnthropicV1BaseSchema.merge(
z.object({
prompt: z.string(),
max_tokens_to_sample: z.coerce
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
})
);
const AnthropicV1MessageMultimodalContentSchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.literal("image"),
source: z.object({
type: z.literal("base64"),
media_type: z.string().max(100),
data: z.string(),
}),
}),
])
);
// https://docs.anthropic.com/claude/reference/messages_post
export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
z.object({
messages: z.array(
z.object({
role: z.enum(["user", "assistant"]),
content: z.union([
z.string(),
AnthropicV1MessageMultimodalContentSchema,
]),
})
),
max_tokens: z
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
system: z.string().optional(),
})
);
export type AnthropicChatMessage = z.infer<
typeof AnthropicV1MessagesSchema
>["messages"][0];
function openAIMessagesToClaudeTextPrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
export const transformOpenAIToAnthropicChat: APIFormatTransformer<
export const transformOpenAIToAnthropicChat: APIRequestTransformer<
typeof AnthropicV1MessagesSchema
> = async (req) => {
const { body } = req;
@ -127,53 +52,11 @@ export const transformOpenAIToAnthropicChat: APIFormatTransformer<
};
};
export const transformOpenAIToAnthropicText: APIFormatTransformer<
typeof AnthropicV1TextSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Anthropic Text request"
);
throw result.error;
}
req.headers["anthropic-version"] = "2023-06-01";
const { messages, ...rest } = result.data;
const prompt = openAIMessagesToClaudeTextPrompt(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
// Recommended by Anthropic
stops.push("\n\nHuman:");
// Helps with jailbreak prompts that send fake system messages and multi-bot
// chats that prefix bot messages with "System: Respond as <bot name>".
stops.push("\n\nSystem:");
// Remove duplicates
stops = [...new Set(stops)];
return {
model: rest.model,
prompt: prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: stops,
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
};
};
/**
* Converts an older Anthropic Text Completion prompt to the newer Messages API
* by splitting the flat text into messages.
*/
export const transformAnthropicTextToAnthropicChat: APIFormatTransformer<
export const transformAnthropicTextToAnthropicChat: APIRequestTransformer<
typeof AnthropicV1MessagesSchema
> = async (req) => {
const { body } = req;
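Illustrative sketch of the intended behavior, assuming the splitter breaks the flat prompt on its Human:/Assistant: turn markers (the actual splitting logic lives further down in this file and is not shown in this hunk):
// Input (Text Completions):
//   { prompt: "\n\nHuman: Hi there.\n\nAssistant: Hello!\n\nHuman: How are you?\n\nAssistant:" }
// Presumed output (Messages API):
//   { messages: [
//       { role: "user", content: "Hi there." },
//       { role: "assistant", content: "Hello!" },
//       { role: "user", content: "How are you?" },
//   ] }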
@ -255,39 +138,6 @@ function validateAnthropicTextPrompt(prompt: string) {
}
}
export function flattenAnthropicMessages(
messages: AnthropicChatMessage[]
): string {
return messages
.map((msg) => {
const name = msg.role === "user" ? "Human" : "Assistant";
const parts = Array.isArray(msg.content)
? msg.content
: [{ type: "text", text: msg.content }];
return `${name}: ${parts
.map((part) =>
part.type === "text"
? part.text
: `[Omitted multimodal content of type ${part.type}]`
)
.join("\n")}`;
})
.join("\n\n");
}
/**
* Represents the union of all content types without the `string` shorthand
* for `text` content.
*/
type AnthropicChatMessageContentWithoutString = Exclude<
AnthropicChatMessage["content"],
string
>;
/** Represents a message with all shorthand `string` content expanded. */
type ConvertedAnthropicChatMessage = AnthropicChatMessage & {
content: AnthropicChatMessageContentWithoutString;
};
function openAIMessagesToClaudeChatPrompt(messages: OpenAIChatMessage[]): {
messages: AnthropicChatMessage[];
system: string;

View File

@ -0,0 +1,52 @@
import { z } from "zod";
import { config } from "../../../../config";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
export const AnthropicV1BaseSchema = z
.object({
model: z.string().max(100),
stop_sequences: z.array(z.string().max(500)).optional(),
stream: z.boolean().optional().default(false),
temperature: z.coerce.number().optional().default(1),
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.object({ user_id: z.string().optional() }).optional(),
})
.strip();
const AnthropicV1MessageMultimodalContentSchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.literal("image"),
source: z.object({
type: z.literal("base64"),
media_type: z.string().max(100),
data: z.string(),
}),
}),
])
);
// https://docs.anthropic.com/claude/reference/messages_post
export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
z.object({
messages: z.array(
z.object({
role: z.enum(["user", "assistant"]),
content: z.union([
z.string(),
AnthropicV1MessageMultimodalContentSchema,
]),
})
),
max_tokens: z
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
system: z.string().optional(),
})
);
export type AnthropicChatMessage = z.infer<
typeof AnthropicV1MessagesSchema
>["messages"][0];

View File

@ -0,0 +1,21 @@
import { AnthropicChatMessage } from "./schema";
export function flattenAnthropicMessages(
messages: AnthropicChatMessage[]
): string {
return messages
.map((msg) => {
const name = msg.role === "user" ? "Human" : "Assistant";
const parts = Array.isArray(msg.content)
? msg.content
: [{ type: "text", text: msg.content }];
return `${name}: ${parts
.map((part) =>
part.type === "text"
? part.text
: `[Omitted multimodal content of type ${part.type}]`
)
.join("\n")}`;
})
.join("\n\n");
}
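A quick example of the stringifier's output, given the implementation above:
// flattenAnthropicMessages([
//   { role: "user", content: "Hi" },
//   { role: "assistant", content: [{ type: "text", text: "Hello" }] },
// ])
// => "Human: Hi\n\nAssistant: Hello"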

View File

@ -0,0 +1,73 @@
import {
AnthropicV1TextSchema,
APIRequestTransformer,
OpenAIChatMessage,
} from "../../index";
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
import { flattenOpenAIMessageContent } from "../openai/stringifier";
export const transformOpenAIToAnthropicText: APIRequestTransformer<
typeof AnthropicV1TextSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Anthropic Text request"
);
throw result.error;
}
req.headers["anthropic-version"] = "2023-06-01";
const { messages, ...rest } = result.data;
const prompt = openAIMessagesToClaudeTextPrompt(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
// Recommended by Anthropic
stops.push("\n\nHuman:");
// Helps with jailbreak prompts that send fake system messages and multi-bot
// chats that prefix bot messages with "System: Respond as <bot name>".
stops.push("\n\nSystem:");
// Remove duplicates
stops = [...new Set(stops)];
return {
model: rest.model,
prompt: prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: stops,
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
};
};
function openAIMessagesToClaudeTextPrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
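For example, the prompt builder above turns an OpenAI-style conversation into Claude's text-completion turn format:
// openAIMessagesToClaudeTextPrompt([
//   { role: "system", content: "Be terse." },
//   { role: "user", content: "Hi", name: "alice" },
// ])
// => "\n\nSystem: Be terse.\n\nHuman: (as alice) Hi\n\nAssistant:"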

View File

@ -0,0 +1,16 @@
import { z } from "zod";
import { AnthropicV1BaseSchema } from "../anthropic-chat/schema";
import { config } from "../../../../config";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
export const AnthropicV1TextSchema = AnthropicV1BaseSchema.merge(
z.object({
prompt: z.string(),
max_tokens_to_sample: z.coerce
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
})
);

View File

@ -0,0 +1 @@
export { GoogleAIChatMessage } from "./schema";

View File

@ -1,43 +1,11 @@
import { z } from "zod";
import {
flattenOpenAIMessageContent,
OpenAIV1ChatCompletionSchema,
} from "./openai";
import { APIFormatTransformer } from "./index";
import { APIRequestTransformer, GoogleAIChatMessage } from "../../index";
import { GoogleAIV1GenerateContentSchema } from "./schema";
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
.object({
model: z.string().max(100), // actually specified in path but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(
z.object({
parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]),
})
),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}),
})
.strip();
export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema
>["contents"][0];
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
export const transformOpenAIToGoogleAI: APIFormatTransformer<
import { flattenOpenAIMessageContent } from "../openai/stringifier";
export const transformOpenAIToGoogleAI: APIRequestTransformer<
typeof GoogleAIV1GenerateContentSchema
> = async (req) => {
const { body } = req;

View File

@ -0,0 +1,34 @@
import { z } from "zod";
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
.object({
model: z.string().max(100), // actually specified in path but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(
z.object({
parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]),
})
),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}),
})
.strip();
export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema
>["contents"][0];

View File

@ -0,0 +1 @@
export { MistralAIChatMessage } from "./schema";

View File

@ -1,29 +1,4 @@
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "./openai";
// https://docs.mistral.ai/api#operation/createChatCompletion
export const MistralAIV1ChatCompletionsSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
})
),
temperature: z.number().optional().default(0.7),
top_p: z.number().optional().default(1),
max_tokens: z.coerce
.number()
.int()
.nullish()
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
stream: z.boolean().optional().default(false),
safe_prompt: z.boolean().optional().default(false),
random_seed: z.number().int().optional(),
});
export type MistralAIChatMessage = z.infer<
typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];
import { MistralAIChatMessage } from "./schema";
export function fixMistralPrompt(
messages: MistralAIChatMessage[]

View File

@ -0,0 +1,28 @@
// https://docs.mistral.ai/api#operation/createChatCompletion
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "../openai/schema";
export const MistralAIV1ChatCompletionsSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
})
),
temperature: z.number().optional().default(0.7),
top_p: z.number().optional().default(1),
max_tokens: z.coerce
.number()
.int()
.nullish()
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
stream: z.boolean().optional().default(false),
safe_prompt: z.boolean().optional().default(false),
random_seed: z.number().int().optional(),
});
export type MistralAIChatMessage = z.infer<
typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];

View File

@ -1,26 +1,9 @@
import { z } from "zod";
import { OpenAIV1ChatCompletionSchema } from "./openai";
import { APIFormatTransformer } from "./index";
/* Takes the last chat message and uses it verbatim as the image prompt. */
import { APIRequestTransformer } from "../../index";
import { OpenAIV1ImagesGenerationSchema } from "./schema";
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
// https://platform.openai.com/docs/api-reference/images/create
export const OpenAIV1ImagesGenerationSchema = z
.object({
prompt: z.string().max(4000),
model: z.string().max(100).optional(),
quality: z.enum(["standard", "hd"]).optional().default("standard"),
n: z.number().int().min(1).max(4).optional().default(1),
response_format: z.enum(["url", "b64_json"]).optional(),
size: z
.enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
.optional()
.default("1024x1024"),
style: z.enum(["vivid", "natural"]).optional().default("vivid"),
user: z.string().max(500).optional(),
})
.strip();
// Takes the last chat message and uses it verbatim as the image prompt.
export const transformOpenAIToOpenAIImage: APIFormatTransformer<
export const transformOpenAIToOpenAIImage: APIRequestTransformer<
typeof OpenAIV1ImagesGenerationSchema
> = async (req) => {
const { body } = req;

View File

@ -0,0 +1,18 @@
// https://platform.openai.com/docs/api-reference/images/create
import { z } from "zod";
export const OpenAIV1ImagesGenerationSchema = z
.object({
prompt: z.string().max(4000),
model: z.string().max(100).optional(),
quality: z.enum(["standard", "hd"]).optional().default("standard"),
n: z.number().int().min(1).max(4).optional().default(1),
response_format: z.enum(["url", "b64_json"]).optional(),
size: z
.enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
.optional()
.default("1024x1024"),
style: z.enum(["vivid", "natural"]).optional().default("vivid"),
user: z.string().max(500).optional(),
})
.strip();

View File

@ -0,0 +1,33 @@
import { APIRequestTransformer } from "../../index";
import { OpenAIV1TextCompletionSchema } from "./schema";
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
import { flattenOpenAIChatMessages } from "../openai/stringifier";
export const transformOpenAIToOpenAIText: APIRequestTransformer<
typeof OpenAIV1TextCompletionSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-text request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push("\n\nUser:");
stops = [...new Set(stops)];
const transformed = { ...rest, prompt: prompt, stop: stops };
return OpenAIV1TextCompletionSchema.parse(transformed);
};

View File

@ -0,0 +1,26 @@
import { z } from "zod";
import { OpenAIV1ChatCompletionSchema } from "../openai/schema";
export const OpenAIV1TextCompletionSchema = z
.object({
model: z
.string()
.max(100)
.regex(
/^gpt-3.5-turbo-instruct/,
"Model must start with 'gpt-3.5-turbo-instruct'"
),
prompt: z.string({
required_error:
"No `prompt` found. Ensure you've set the correct completion endpoint.",
}),
logprobs: z.number().int().nullish().default(null),
echo: z.boolean().optional().default(false),
best_of: z.literal(1).optional(),
stop: z
.union([z.string().max(500), z.array(z.string().max(500)).max(4)])
.optional(),
suffix: z.string().max(1000).optional(),
})
.strip()
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true, logprobs: true }));

View File

@ -0,0 +1,13 @@
import { APIFormatKit } from "../../index";
import { OpenAIChatMessage, OpenAIV1ChatCompletionSchema } from "./schema";
import { flattenOpenAIChatMessages } from "./stringifier";
import { getOpenAITokenCount } from "./tokenizer";
const kit: APIFormatKit<"openai", OpenAIChatMessage[]> = {
name: "openai",
requestValidator: OpenAIV1ChatCompletionSchema,
// We never transform from other formats into OpenAI format.
requestTransformers: {},
// No response transformers exist for the OpenAI format yet.
responseTransformers: {},
promptStringifier: flattenOpenAIChatMessages,
// getOpenAITokenCount returns { tokenizer, token_count }, but the kit
// interface expects a bare number, so unwrap the count here.
promptTokenCounter: async (prompt, model) =>
(await getOpenAITokenCount(prompt, model)).token_count,
// Completions are plain strings; the same tiktoken-based counter applies.
completionTokenCounter: async (completion, model) =>
(await getOpenAITokenCount(completion, model)).token_count,
};
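A hypothetical caller might then use the assembled kit like so (illustrative only, not part of this commit):
// const messages: OpenAIChatMessage[] = [{ role: "user", content: "Hello" }];
// const promptText = kit.promptStringifier(messages);
// const promptTokens = await kit.promptTokenCounter(messages, "gpt-4");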

View File

@ -1,8 +1,7 @@
import { z } from "zod";
import { config } from "../../config";
import { config } from "../../../../config";
export const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
z.union([
@ -81,53 +80,3 @@ export const OpenAIV1ChatCompletionSchema = z
export type OpenAIChatMessage = z.infer<
typeof OpenAIV1ChatCompletionSchema
>["messages"][0];
export function flattenOpenAIMessageContent(
content: OpenAIChatMessage["content"]
): string {
return Array.isArray(content)
? content
.map((contentItem) => {
if ("text" in contentItem) return contentItem.text;
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
})
.join("\n")
: content;
}
export function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
// Temporary to allow experimenting with prompt strategies
const PROMPT_VERSION: number = 1;
switch (PROMPT_VERSION) {
case 1:
return (
messages
.map((m) => {
// Claude-style human/assistant turns
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
case 2:
return messages
.map((m) => {
// Claude without prefixes (except system) and no Assistant priming
let role: string = "";
if (m.role === "system") {
role = "System: ";
}
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
})
.join("");
default:
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
}
}

View File

@ -0,0 +1,33 @@
import { OpenAIChatMessage } from "./schema";
export function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
// Claude-style human/assistant turns
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
}
export function flattenOpenAIMessageContent(
content: OpenAIChatMessage["content"],
): string {
return Array.isArray(content)
? content
.map((contentItem) => {
if ("text" in contentItem) return contentItem.text;
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
})
.join("\n")
: content;
}
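Given the implementation above, the flattener produces Claude-style turns with an Assistant priming suffix:
// flattenOpenAIChatMessages([
//   { role: "system", content: "Be terse." },
//   { role: "user", content: "Hi" },
// ])
// => "\n\nSystem: Be terse.\n\nUser: Hi\n\nAssistant:"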

View File

@ -0,0 +1,154 @@
import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../../../logger";
import { libSharp } from "../../../file-storage";
import { OpenAIChatMessage } from "./schema";
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
const log = logger.child({ module: "tokenizer", service: "openai" });
export const encoder = new Tiktoken(
cl100k_base.bpe_ranks,
cl100k_base.special_tokens,
cl100k_base.pat_str
);
export async function getOpenAITokenCount(
prompt: string | OpenAIChatMessage[],
model: string
) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const oldFormatting = model.startsWith("turbo-0301");
const vision = model.includes("vision");
const tokensPerMessage = oldFormatting ? 4 : 3;
const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
for (const key of Object.keys(message)) {
{
let textContent: string = "";
const value = message[key as keyof OpenAIChatMessage];
if (!value) continue;
if (Array.isArray(value)) {
for (const item of value) {
if (item.type === "text") {
textContent += item.text;
} else if (["image", "image_url"].includes(item.type)) {
const { url, detail } = item.image_url;
const cost = await getGpt4VisionTokenCost(url, detail);
numTokens += cost ?? 0;
}
}
} else {
textContent = value;
}
if (textContent.length > 800000 || numTokens > 200000) {
throw new Error("Content is too large to tokenize.");
}
numTokens += encoder.encode(textContent).length;
if (key === "name") {
numTokens += tokensPerName;
}
}
}
}
numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
return { tokenizer: "tiktoken", token_count: numTokens };
}
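As a quick sanity check of the per-message accounting above (token counts via cl100k_base; "user" and "Hello" each encode to a single token):
// getOpenAITokenCount([{ role: "user", content: "Hello" }], "gpt-4")
// => 3 (per-message) + 1 ("user") + 1 ("Hello") + 3 (reply priming)
// => { tokenizer: "tiktoken", token_count: 8 }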
async function getGpt4VisionTokenCost(
url: string,
detail: "auto" | "low" | "high" = "auto"
) {
// For now we do not allow remote images as the proxy would have to download
// them, which is a potential DoS vector.
if (!url.startsWith("data:image/")) {
throw new Error(
"Remote images are not supported. Add the image to your prompt as a base64 data URL."
);
}
const base64Data = url.split(",")[1];
const buffer = Buffer.from(base64Data, "base64");
const image = libSharp(buffer);
const metadata = await image.metadata();
if (!metadata || !metadata.width || !metadata.height) {
throw new Error("Prompt includes an image that could not be parsed");
}
const { width, height } = metadata;
let selectedDetail: "low" | "high";
if (detail === "auto") {
const threshold = 512 * 512;
const imageSize = width * height;
selectedDetail = imageSize > threshold ? "high" : "low";
} else {
selectedDetail = detail;
}
// https://platform.openai.com/docs/guides/vision/calculating-costs
if (selectedDetail === "low") {
log.info(
{ width, height, tokens: 85 },
"Using fixed GPT-4-Vision token cost for low detail image"
);
return 85;
}
let newWidth = width;
let newHeight = height;
if (width > 2048 || height > 2048) {
const aspectRatio = width / height;
if (width > height) {
newWidth = 2048;
newHeight = Math.round(2048 / aspectRatio);
} else {
newHeight = 2048;
newWidth = Math.round(2048 * aspectRatio);
}
}
if (newWidth < newHeight) {
newHeight = Math.round((newHeight / newWidth) * 768);
newWidth = 768;
} else {
newWidth = Math.round((newWidth / newHeight) * 768);
newHeight = 768;
}
const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
const tokens = 170 * tiles + 85;
log.info(
{ width, height, newWidth, newHeight, tiles, tokens },
"Calculated GPT-4-Vision token cost for high detail image"
);
return tokens;
}
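A worked example of the high-detail path above, following OpenAI's published calculation:
// A 1600x2400 image at detail: "high":
// 1. Cap the long side at 2048:    1600x2400 -> 1365x2048
// 2. Scale the short side to 768:  1365x2048 -> 768x1152
// 3. Count 512px tiles:            ceil(768/512) * ceil(1152/512) = 2 * 3 = 6
// 4. Tokens: 170 * 6 + 85 = 1105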
export function getTextTokenCount(prompt: string) {
if (prompt.length > 500000) {
return {
tokenizer: "length fallback",
token_count: 100000,
};
}
return {
tokenizer: "tiktoken",
token_count: encoder.encode(prompt).length,
};
}

View File

@ -1,6 +1,6 @@
import { getTokenizer } from "@anthropic-ai/tokenizer";
import { Tiktoken } from "tiktoken/lite";
import { AnthropicChatMessage } from "../api-schemas";
import { AnthropicChatMessage } from "../api-support";
import { libSharp } from "../file-storage";
import { logger } from "../../logger";

View File

@ -1,5 +1,5 @@
import * as tokenizer from "./mistral-tokenizer-js";
import { MistralAIChatMessage } from "../api-schemas";
import { MistralAIChatMessage } from "../api-support";
export function init() {
tokenizer.initializemistralTokenizer();

View File

@ -1,166 +1,9 @@
import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import { GoogleAIChatMessage, OpenAIChatMessage } from "../api-schemas";
const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
let encoder: Tiktoken;
export function init() {
encoder = new Tiktoken(
cl100k_base.bpe_ranks,
cl100k_base.special_tokens,
cl100k_base.pat_str
);
return true;
}
import { GoogleAIChatMessage } from "../api-support";
import { encoder, getTextTokenCount } from "../api-support/kits/openai/tokenizer";
// Tested against:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
export async function getTokenCount(
prompt: string | OpenAIChatMessage[],
model: string
) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const oldFormatting = model.startsWith("turbo-0301");
const vision = model.includes("vision");
const tokensPerMessage = oldFormatting ? 4 : 3;
const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
for (const key of Object.keys(message)) {
{
let textContent: string = "";
const value = message[key as keyof OpenAIChatMessage];
if (!value) continue;
if (Array.isArray(value)) {
for (const item of value) {
if (item.type === "text") {
textContent += item.text;
} else if (["image", "image_url"].includes(item.type)) {
const { url, detail } = item.image_url;
const cost = await getGpt4VisionTokenCost(url, detail);
numTokens += cost ?? 0;
}
}
} else {
textContent = value;
}
if (textContent.length > 800000 || numTokens > 200000) {
throw new Error("Content is too large to tokenize.");
}
numTokens += encoder.encode(textContent).length;
if (key === "name") {
numTokens += tokensPerName;
}
}
}
}
numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
return { tokenizer: "tiktoken", token_count: numTokens };
}
async function getGpt4VisionTokenCost(
url: string,
detail: "auto" | "low" | "high" = "auto"
) {
// For now we do not allow remote images as the proxy would have to download
// them, which is a potential DoS vector.
if (!url.startsWith("data:image/")) {
throw new Error(
"Remote images are not supported. Add the image to your prompt as a base64 data URL."
);
}
const base64Data = url.split(",")[1];
const buffer = Buffer.from(base64Data, "base64");
const image = libSharp(buffer);
const metadata = await image.metadata();
if (!metadata || !metadata.width || !metadata.height) {
throw new Error("Prompt includes an image that could not be parsed");
}
const { width, height } = metadata;
let selectedDetail: "low" | "high";
if (detail === "auto") {
const threshold = 512 * 512;
const imageSize = width * height;
selectedDetail = imageSize > threshold ? "high" : "low";
} else {
selectedDetail = detail;
}
// https://platform.openai.com/docs/guides/vision/calculating-costs
if (selectedDetail === "low") {
log.info(
{ width, height, tokens: 85 },
"Using fixed GPT-4-Vision token cost for low detail image"
);
return 85;
}
let newWidth = width;
let newHeight = height;
if (width > 2048 || height > 2048) {
const aspectRatio = width / height;
if (width > height) {
newWidth = 2048;
newHeight = Math.round(2048 / aspectRatio);
} else {
newHeight = 2048;
newWidth = Math.round(2048 * aspectRatio);
}
}
if (newWidth < newHeight) {
newHeight = Math.round((newHeight / newWidth) * 768);
newWidth = 768;
} else {
newWidth = Math.round((newWidth / newHeight) * 768);
newHeight = 768;
}
const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
const tokens = 170 * tiles + 85;
log.info(
{ width, height, newWidth, newHeight, tiles, tokens },
"Calculated GPT-4-Vision token cost for high detail image"
);
return tokens;
}
function getTextTokenCount(prompt: string) {
if (prompt.length > 500000) {
return {
tokenizer: "length fallback",
token_count: 100000,
};
}
return {
tokenizer: "tiktoken",
token_count: encoder.encode(prompt).length,
};
}
// Model     Resolution            Price
// DALL·E 3  1024×1024             $0.040 / image
//           1024×1792, 1792×1024  $0.080 / image

View File

@ -7,8 +7,7 @@ import {
import {
estimateGoogleAITokenCount,
getOpenAIImageCost,
getTokenCount as getOpenAITokenCount,
init as initOpenAi,
} from "./openai";
import {
getTokenCount as getMistralAITokenCount,
@ -20,7 +19,8 @@ import {
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../api-schemas";
} from "../api-support";
import {
getOpenAITokenCount,
init as initOpenAi,
} from "../api-support/kits/openai/tokenizer";
export async function init() {
initClaude();