refactors API transformers and adds OpenAI->Anthropic Chat API translation

This commit is contained in:
nai-degen 2024-03-08 20:59:19 -06:00
parent 8d84f289b2
commit fab404b232
12 changed files with 440 additions and 142 deletions

View File

@ -83,17 +83,19 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`; body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
} }
if (req.inboundApi === "openai") { switch (`${req.inboundApi}<-${req.outboundApi}`) {
req.log.info("Transforming Anthropic text to OpenAI format"); case "openai<-anthropic-text":
body = transformAnthropicTextResponseToOpenAI(body, req); req.log.info("Transforming Anthropic Text back to OpenAI format");
} body = transformAnthropicTextResponseToOpenAI(body, req);
break;
if ( case "openai<-anthropic-chat":
req.inboundApi === "anthropic-text" && req.log.info("Transforming Anthropic Chat back to OpenAI format");
req.outboundApi === "anthropic-chat" body = transformAnthropicChatResponseToOpenAI(body);
) { break;
req.log.info("Transforming Anthropic text to Anthropic chat format"); case "anthropic-text<-anthropic-chat":
body = transformAnthropicChatResponseToAnthropicText(body); req.log.info("Transforming Anthropic Chat back to Anthropic chat format");
body = transformAnthropicChatResponseToAnthropicText(body);
break;
} }
if (req.tokenizerInfo) { if (req.tokenizerInfo) {
@ -103,17 +105,23 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json(body); res.status(200).json(body);
}; };
/**
 * Flattens an Anthropic content-block array into a single string, joining
 * the text of each `text` block with newlines. Non-text blocks contribute
 * an empty string (preserving a separator for each block).
 */
function flattenChatResponse(
  content: { type: string; text: string }[]
): string {
  const pieces: string[] = [];
  for (const block of content) {
    pieces.push(block.type === "text" ? block.text : "");
  }
  return pieces.join("\n");
}
export function transformAnthropicChatResponseToAnthropicText( export function transformAnthropicChatResponseToAnthropicText(
anthropicBody: Record<string, any> anthropicBody: Record<string, any>
): Record<string, any> { ): Record<string, any> {
return { return {
type: "completion", type: "completion",
id: "trans-" + anthropicBody.id, id: "ant-" + anthropicBody.id,
completion: anthropicBody.content completion: flattenChatResponse(anthropicBody.content),
.map((part: { type: string; text: string }) =>
part.type === "text" ? part.text : ""
)
.join(""),
stop_reason: anthropicBody.stop_reason, stop_reason: anthropicBody.stop_reason,
stop: anthropicBody.stop_sequence, stop: anthropicBody.stop_sequence,
model: anthropicBody.model, model: anthropicBody.model,
@ -155,6 +163,28 @@ function transformAnthropicTextResponseToOpenAI(
}; };
} }
/**
 * Transforms a non-streaming Anthropic Messages API response body into an
 * OpenAI chat completion response body so OpenAI-compatible clients can
 * consume it.
 */
function transformAnthropicChatResponseToOpenAI(
  anthropicBody: Record<string, any>
): Record<string, any> {
  return {
    id: "ant-" + anthropicBody.id,
    object: "chat.completion",
    // OpenAI's `created` is a Unix timestamp in *seconds*; Date.now() returns
    // milliseconds, which confuses clients that render the timestamp.
    created: Math.floor(Date.now() / 1000),
    model: anthropicBody.model,
    // NOTE(review): Anthropic usage uses input_tokens/output_tokens, while
    // OpenAI clients expect prompt_tokens/completion_tokens/total_tokens.
    // Passed through verbatim here — confirm whether consumers need mapping.
    usage: anthropicBody.usage,
    choices: [
      {
        message: {
          role: "assistant",
          content: flattenChatResponse(anthropicBody.content),
        },
        // NOTE(review): Anthropic stop_reason values ("end_turn",
        // "max_tokens", "stop_sequence") are not mapped to OpenAI's
        // "stop"/"length" — confirm clients tolerate the raw values.
        finish_reason: anthropicBody.stop_reason,
        index: 0,
      },
    ],
  };
}
const anthropicProxy = createQueueMiddleware({ const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({ proxyMiddleware: createProxyMiddleware({
target: "https://api.anthropic.com", target: "https://api.anthropic.com",
@ -178,6 +208,9 @@ const anthropicProxy = createQueueMiddleware({
if (isText && pathname === "/v1/chat/completions") { if (isText && pathname === "/v1/chat/completions") {
req.url = "/v1/complete"; req.url = "/v1/complete";
} }
if (isChat && pathname === "/v1/chat/completions") {
req.url = "/v1/messages";
}
if (isChat && ["sonnet", "opus"].includes(req.params.type)) { if (isChat && ["sonnet", "opus"].includes(req.params.type)) {
req.url = "/v1/messages"; req.url = "/v1/messages";
} }
@ -202,7 +235,7 @@ const textToChatPreprocessor = createPreprocessorMiddleware({
* Routes text completion prompts to anthropic-chat if they need translation * Routes text completion prompts to anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint). * (claude-3 based models do not support the old text completion endpoint).
*/ */
const claudeTextCompletionRouter: RequestHandler = (req, res, next) => { const preprocessAnthropicTextRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.startsWith("claude-3")) { if (req.body.model?.startsWith("claude-3")) {
textToChatPreprocessor(req, res, next); textToChatPreprocessor(req, res, next);
} else { } else {
@ -210,15 +243,33 @@ const claudeTextCompletionRouter: RequestHandler = (req, res, next) => {
} }
}; };
// Preprocessor pipeline for OpenAI-formatted requests that will be served by
// the legacy Anthropic Text completion API.
const oaiToTextPreprocessor = createPreprocessorMiddleware({
  inApi: "openai",
  outApi: "anthropic-text",
  service: "anthropic",
});
// Preprocessor pipeline for OpenAI-formatted requests that will be served by
// the Anthropic Messages (chat) API.
const oaiToChatPreprocessor = createPreprocessorMiddleware({
  inApi: "openai",
  outApi: "anthropic-chat",
  service: "anthropic",
});
/**
 * Routes an OpenAI prompt to either the legacy Claude text completion endpoint
 * or the new Claude chat completion endpoint, based on the requested model.
 */
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
  maybeReassignModel(req);
  // Claude 3 models only support the Messages (chat) API.
  const wantsChatApi = Boolean(req.body.model?.includes("claude-3"));
  const preprocessor = wantsChatApi ? oaiToChatPreprocessor : oaiToTextPreprocessor;
  preprocessor(req, res, next);
};
const anthropicRouter = Router(); const anthropicRouter = Router();
anthropicRouter.get("/v1/models", handleModelRequest); anthropicRouter.get("/v1/models", handleModelRequest);
// Anthropic text completion endpoint. Dynamic routing based on model.
anthropicRouter.post(
"/v1/complete",
ipLimiter,
claudeTextCompletionRouter,
anthropicProxy
);
// Native Anthropic chat completion endpoint. // Native Anthropic chat completion endpoint.
anthropicRouter.post( anthropicRouter.post(
"/v1/messages", "/v1/messages",
@ -230,23 +281,30 @@ anthropicRouter.post(
}), }),
anthropicProxy anthropicProxy
); );
// OpenAI-to-Anthropic Text compatibility endpoint. // Anthropic text completion endpoint. Translates to Anthropic chat completion
// if the requested model is a Claude 3 model.
anthropicRouter.post(
"/v1/complete",
ipLimiter,
preprocessAnthropicTextRequest,
anthropicProxy
);
// OpenAI-to-Anthropic compatibility endpoint. Accepts an OpenAI chat completion
// request and transforms/routes it to the appropriate Anthropic format and
// endpoint based on the requested model.
anthropicRouter.post( anthropicRouter.post(
"/v1/chat/completions", "/v1/chat/completions",
ipLimiter, ipLimiter,
createPreprocessorMiddleware( preprocessOpenAICompatRequest,
{ inApi: "openai", outApi: "anthropic-text", service: "anthropic" },
{ afterTransform: [maybeReassignModel] }
),
anthropicProxy anthropicProxy
); );
// Temporary force Anthropic Text to Anthropic Chat for frontends which do not // Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
// yet support the new model. Forces claude-3. Will be removed once common // yet support the new model. Forces claude-3. Will be removed once common
// frontends have been updated. // frontends have been updated.
anthropicRouter.post( anthropicRouter.post(
"/v1/:type(sonnet|opus)/:action(complete|messages)", "/v1/:type(sonnet|opus)/:action(complete|messages)",
ipLimiter, ipLimiter,
handleCompatibilityRequest, handleAnthropicTextCompatRequest,
createPreprocessorMiddleware({ createPreprocessorMiddleware({
inApi: "anthropic-text", inApi: "anthropic-text",
outApi: "anthropic-chat", outApi: "anthropic-chat",
@ -255,7 +313,11 @@ anthropicRouter.post(
anthropicProxy anthropicProxy
); );
function handleCompatibilityRequest(req: Request, res: Response, next: any) { function handleAnthropicTextCompatRequest(
req: Request,
res: Response,
next: any
) {
const type = req.params.type; const type = req.params.type;
const action = req.params.action; const action = req.params.action;
const alreadyInChatFormat = Boolean(req.body.messages); const alreadyInChatFormat = Boolean(req.body.messages);
@ -287,10 +349,14 @@ function handleCompatibilityRequest(req: Request, res: Response, next: any) {
next(); next();
} }
/**
 * If a client using the OpenAI compatibility endpoint requests an actual
 * OpenAI model, reassigns it to Claude 3 Sonnet.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;
  // Guard against requests that omit `model` entirely (the router elsewhere
  // uses `req.body.model?.` — the field can be absent); let the downstream
  // validator reject those with a clearer error than a TypeError here.
  if (typeof model !== "string" || !model.startsWith("gpt-")) return;
  req.body.model = "claude-3-sonnet-20240229";
}
export const anthropic = anthropicRouter; export const anthropic = anthropicRouter;

View File

@ -5,7 +5,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import { import {
AnthropicV1TextSchema, AnthropicV1TextSchema,
AnthropicV1MessagesSchema, AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas/anthropic"; } from "../../../../shared/api-schemas";
import { keyPool } from "../../../../shared/key-management"; import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index"; import { RequestPreprocessor } from "../index";

View File

@ -1,12 +1,9 @@
import { import {
anthropicTextToAnthropicChat, API_REQUEST_VALIDATORS,
openAIToAnthropicText, API_REQUEST_TRANSFORMERS,
} from "../../../../shared/api-schemas/anthropic"; } from "../../../../shared/api-schemas";
import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text"; import { BadRequestError } from "../../../../shared/errors";
import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image";
import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai";
import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai"; import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai";
import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas";
import { import {
isImageGenerationRequest, isImageGenerationRequest,
isTextGenerationRequest, isTextGenerationRequest,
@ -22,6 +19,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
if (alreadyTransformed || notTransformable) return; if (alreadyTransformed || notTransformable) return;
// TODO: this should be an APIFormatTransformer
if (req.inboundApi === "mistral-ai") { if (req.inboundApi === "mistral-ai") {
const messages = req.body.messages; const messages = req.body.messages;
req.body.messages = fixMistralPrompt(messages); req.body.messages = fixMistralPrompt(messages);
@ -32,9 +30,9 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
} }
if (sameService) { if (sameService) {
const result = API_SCHEMA_VALIDATORS[req.inboundApi].safeParse(req.body); const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
if (!result.success) { if (!result.success) {
req.log.error( req.log.warn(
{ issues: result.error.issues, body: req.body }, { issues: result.error.issues, body: req.body },
"Request validation failed" "Request validation failed"
); );
@ -44,35 +42,16 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return; return;
} }
if ( const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
req.inboundApi === "anthropic-text" && const transFn = API_REQUEST_TRANSFORMERS[transformation];
req.outboundApi === "anthropic-chat"
) { if (transFn) {
req.body = anthropicTextToAnthropicChat(req); req.log.info({ transformation }, "Transforming request");
req.body = await transFn(req);
return; return;
} }
if (req.inboundApi === "openai" && req.outboundApi === "anthropic-text") { throw new BadRequestError(
req.body = openAIToAnthropicText(req); `${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
req.body = openAIToGoogleAI(req);
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "openai-text") {
req.body = openAIToOpenAIText(req);
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
req.body = openAIToOpenAIImage(req);
return;
}
throw new Error(
`'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
); );
}; };

View File

@ -39,6 +39,7 @@ export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai"; export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai"; export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2"; export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai"; export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai"; export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat"; export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";

View File

@ -3,6 +3,7 @@ import { logger } from "../../../../logger";
import { APIFormat } from "../../../../shared/key-management"; import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils"; import { assertNever } from "../../../../shared/utils";
import { import {
anthropicChatToOpenAI,
anthropicChatToAnthropicV2, anthropicChatToAnthropicV2,
anthropicV1ToOpenAI, anthropicV1ToOpenAI,
AnthropicV2StreamEvent, AnthropicV2StreamEvent,
@ -117,7 +118,11 @@ function eventIsOpenAIEvent(
function getTransformer( function getTransformer(
responseApi: APIFormat, responseApi: APIFormat,
version?: string version?: string,
// There's only one case where we're not transforming back to OpenAI, which is
// Anthropic Chat response -> Anthropic Text request. This parameter is only
// used for that case.
requestApi: APIFormat = "openai"
): StreamingCompletionTransformer< ): StreamingCompletionTransformer<
OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
> { > {
@ -132,7 +137,9 @@ function getTransformer(
? anthropicV1ToOpenAI ? anthropicV1ToOpenAI
: anthropicV2ToOpenAI; : anthropicV2ToOpenAI;
case "anthropic-chat": case "anthropic-chat":
return anthropicChatToAnthropicV2; return requestApi === "anthropic-text"
? anthropicChatToAnthropicV2
: anthropicChatToOpenAI;
case "google-ai": case "google-ai":
return googleAIToOpenAI; return googleAIToOpenAI;
case "openai-image": case "openai-image":

View File

@ -1,11 +1,11 @@
import { z } from "zod"; import { z } from "zod";
import { Request } from "express";
import { config } from "../../config"; import { config } from "../../config";
import { import {
flattenOpenAIMessageContent, flattenOpenAIMessageContent,
OpenAIChatMessage, OpenAIChatMessage,
OpenAIV1ChatCompletionSchema, OpenAIV1ChatCompletionSchema,
} from "./openai"; } from "./openai";
import { APIFormatTransformer } from "./index";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic; const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
@ -69,9 +69,7 @@ export type AnthropicChatMessage = z.infer<
typeof AnthropicV1MessagesSchema typeof AnthropicV1MessagesSchema
>["messages"][0]; >["messages"][0];
export function openAIMessagesToClaudeTextPrompt( function openAIMessagesToClaudeTextPrompt(messages: OpenAIChatMessage[]) {
messages: OpenAIChatMessage[]
) {
return ( return (
messages messages
.map((m) => { .map((m) => {
@ -93,7 +91,44 @@ export function openAIMessagesToClaudeTextPrompt(
); );
} }
export function openAIToAnthropicText(req: Request) { export const transformOpenAIToAnthropicChat: APIFormatTransformer<
typeof AnthropicV1MessagesSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Anthropic Chat request"
);
throw result.error;
}
req.headers["anthropic-version"] = "2023-06-01";
const { messages, ...rest } = result.data;
const { messages: newMessages, system } =
openAIMessagesToClaudeChatPrompt(messages);
return {
system,
messages: newMessages,
model: rest.model,
max_tokens: rest.max_tokens,
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
stop_sequences: typeof rest.stop === "string" ? [rest.stop] : rest.stop,
...(rest.user ? { metadata: { user_id: rest.user } } : {}),
// Anthropic supports top_k, but OpenAI does not
// OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
// and function calls, but Anthropic does not.
};
};
export const transformOpenAIToAnthropicText: APIFormatTransformer<
typeof AnthropicV1TextSchema
> = async (req) => {
const { body } = req; const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body); const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) { if (!result.success) {
@ -131,13 +166,15 @@ export function openAIToAnthropicText(req: Request) {
temperature: rest.temperature, temperature: rest.temperature,
top_p: rest.top_p, top_p: rest.top_p,
}; };
} };
/** /**
* Converts an older Anthropic Text Completion prompt to the newer Messages API * Converts an older Anthropic Text Completion prompt to the newer Messages API
* by splitting the flat text into messages. * by splitting the flat text into messages.
*/ */
export function anthropicTextToAnthropicChat(req: Request) { export const transformAnthropicTextToAnthropicChat: APIFormatTransformer<
typeof AnthropicV1MessagesSchema
> = async (req) => {
const { body } = req; const { body } = req;
const result = AnthropicV1TextSchema.safeParse(body); const result = AnthropicV1TextSchema.safeParse(body);
if (!result.success) { if (!result.success) {
@ -163,8 +200,8 @@ export function anthropicTextToAnthropicChat(req: Request) {
while (remaining) { while (remaining) {
const isHuman = remaining.startsWith("\n\nHuman:"); const isHuman = remaining.startsWith("\n\nHuman:");
// TODO: Are multiple consecutive human or assistant messages allowed? // Multiple messages from the same role are not permitted in Messages API.
// Currently we will enforce alternating turns. // We collect all messages until the next message from the opposite role.
const thisRole = isHuman ? "\n\nHuman:" : "\n\nAssistant:"; const thisRole = isHuman ? "\n\nHuman:" : "\n\nAssistant:";
const nextRole = isHuman ? "\n\nAssistant:" : "\n\nHuman:"; const nextRole = isHuman ? "\n\nAssistant:" : "\n\nHuman:";
const nextIndex = remaining.indexOf(nextRole); const nextIndex = remaining.indexOf(nextRole);
@ -199,7 +236,7 @@ export function anthropicTextToAnthropicChat(req: Request) {
max_tokens: max_tokens_to_sample, max_tokens: max_tokens_to_sample,
...rest, ...rest,
}; };
} };
function validateAnthropicTextPrompt(prompt: string) { function validateAnthropicTextPrompt(prompt: string) {
if (!prompt.includes("\n\nHuman:") || !prompt.includes("\n\nAssistant:")) { if (!prompt.includes("\n\nHuman:") || !prompt.includes("\n\nAssistant:")) {
@ -236,3 +273,167 @@ export function flattenAnthropicMessages(
}) })
.join("\n\n"); .join("\n\n");
} }
/**
 * Represents the union of all content types without the `string` shorthand
 * for `text` content.
 */
type AnthropicChatMessageContentWithoutString = Exclude<
  AnthropicChatMessage["content"],
  string
>;

/**
 * Represents a message with all shorthand `string` content expanded.
 * Used internally so message-merging code can always `push` additional
 * content blocks onto an existing message's content array.
 */
type ConvertedAnthropicChatMessage = AnthropicChatMessage & {
  content: AnthropicChatMessageContentWithoutString;
};
/**
 * Converts an OpenAI chat message list into the Anthropic Messages format.
 * Leading system messages are folded into the returned `system` string; the
 * remaining messages are merged/reassigned so that no two consecutive
 * messages share a role and the first message is from `user`, per the
 * Messages API's rules.
 */
function openAIMessagesToClaudeChatPrompt(messages: OpenAIChatMessage[]): {
  messages: AnthropicChatMessage[];
  system: string;
} {
  // Similar formats, but Claude doesn't use `name` property and doesn't have
  // a `system` role. Also, Claude does not allow consecutive messages from
  // the same role, so we need to merge them.
  // 1. Collect all system messages up to the first non-system message and set
  //    that as the `system` prompt.
  // 2. Iterate through messages and:
  //    - If the message is from system, reassign it to assistant with System:
  //      prefix.
  //    - If message is from same role as previous, append it to the previous
  //      message rather than creating a new one.
  //    - Otherwise, create a new message and prefix with `name` if present.

  // TODO: When a Claude message has multiple `text` contents, does the internal
  // message flattening insert newlines between them? If not, we may need to
  // do that here...
  let firstNonSystem = -1;
  const result: { messages: ConvertedAnthropicChatMessage[]; system: string } =
    { messages: [], system: "" };
  for (let i = 0; i < messages.length; i++) {
    const msg = messages[i];
    const isSystem = isSystemOpenAIRole(msg.role);

    if (firstNonSystem === -1 && isSystem) {
      // Still merging initial system messages into the system prompt
      result.system += getFirstTextContent(msg.content) + "\n";
      continue;
    }

    if (firstNonSystem === -1 && !isSystem) {
      // Encountered the first non-system message
      firstNonSystem = i;
      if (msg.role === "assistant") {
        // There is an annoying rule that the first message must be from the user.
        // This is commonly not the case with roleplay prompts that start with a
        // block of system messages followed by an assistant message. We will try
        // to reconcile this by splicing the last line of the system prompt into
        // a beginning user message -- this is *commonly* ST's [Start a new chat]
        // nudge, which works okay as a user message.

        // Find the last non-empty line in the system prompt
        const execResult = /(?:[^\r\n]*\r?\n)*([^\r\n]+)(?:\r?\n)*/d.exec(
          result.system
        );

        let text = "";
        if (execResult) {
          text = execResult[1];
          // Remove last line from system so it doesn't get duplicated.
          // NOTE(review): the `d` flag guarantees `indices` is set, so the
          // `|| []` fallback should be unreachable — if it were ever taken,
          // the nested destructuring would itself throw. Confirm before
          // relying on the fallback.
          const [_, [lastLineStart]] = execResult.indices || [];
          result.system = result.system.slice(0, lastLineStart);
        } else {
          // This is a bad prompt; there's no system content to move to user and
          // it starts with assistant. We don't have any good options.
          text = "[ Joining chat... ]";
        }
        result.messages.push({
          role: "user",
          content: [{ type: "text", text }],
        });
      }
      // Deliberate fall-through: the current (first non-system) message still
      // needs to be converted and appended below.
    }

    const last = result.messages[result.messages.length - 1];
    // I have to handle tools as system messages to be exhaustive here but the
    // experience will be bad.
    const role = isSystemOpenAIRole(msg.role) ? "assistant" : msg.role;
    // Here we will lose the original name if it was a system message, but that
    // is generally okay because the system message is usually a prompt and not
    // a character in the chat.
    const name = msg.role === "system" ? "System" : msg.name?.trim();
    const content = convertOpenAIContent(msg.content);

    // Prepend the display name to the first text content in the current message
    // if it exists. We don't need to add the name to every content block.
    if (name?.length) {
      const firstTextContent = content.find((c) => c.type === "text");
      if (firstTextContent && "text" in firstTextContent) {
        // This mutates the element in `content`.
        firstTextContent.text = `${name}: ${firstTextContent.text}`;
      }
    }

    // Merge messages if necessary. If two assistant roles are consecutive but
    // had different names, the final converted assistant message will have
    // multiple characters in it, but the name prefixes should assist the model
    // in differentiating between speakers.
    if (last && last.role === role) {
      last.content.push(...content);
    } else {
      result.messages.push({ role, content });
    }
  }
  result.system = result.system.trimEnd();
  return result;
}
/**
 * Type guard for OpenAI roles that have no Claude equivalent and are treated
 * as system-style messages during conversion.
 */
function isSystemOpenAIRole(
  role: OpenAIChatMessage["role"]
): role is "system" | "function" | "tool" {
  switch (role) {
    case "system":
    case "function":
    case "tool":
      return true;
    default:
      return false;
  }
}
/**
 * Returns message content as a plain string: the content itself when it uses
 * the string shorthand, otherwise the text of the first content block that
 * has a `text` field, or a placeholder when none exists.
 */
function getFirstTextContent(content: OpenAIChatMessage["content"]) {
  if (typeof content === "string") return content;
  const firstText = content.find((block) => "text" in block);
  return firstText && "text" in firstText
    ? firstText.text
    : "[ No text content in this message ]";
}
/**
 * Expands OpenAI message content (string shorthand or content-part array)
 * into Anthropic content blocks. Image URLs must be inline base64 `data:`
 * URLs; anything else is replaced with a placeholder text block.
 */
function convertOpenAIContent(
  content: OpenAIChatMessage["content"]
): AnthropicChatMessageContentWithoutString {
  if (typeof content === "string") {
    return [{ type: "text", text: content.trimEnd() }];
  }
  return content.map((c) => {
    if ("text" in c) {
      return { type: "text", text: c.text.trimEnd() };
    } else if ("image_url" in c) {
      const url = c.image_url.url;
      // The previous split-based parse never threw, so plain http(s) URLs
      // silently produced invalid image blocks (media_type "//…", data
      // undefined) instead of reaching the fallback. Validate the data-URL
      // shape explicitly so non-inline images get the placeholder text.
      const match = /^data:([^;,]+);base64,(.+)$/.exec(url);
      if (match) {
        const [, mimeType, data] = match;
        return {
          type: "image",
          source: { type: "base64", media_type: mimeType, data },
        };
      }
      return {
        type: "text",
        text: `[ Unsupported image URL: ${url.slice(0, 200)} ]`,
      };
    } else {
      const type = String((c as any)?.type);
      return { type: "text", text: `[ Unsupported content type: ${type} ]` };
    }
  });
}

View File

@ -1,9 +1,9 @@
import { z } from "zod"; import { z } from "zod";
import { Request } from "express";
import { import {
flattenOpenAIMessageContent, flattenOpenAIMessageContent,
OpenAIV1ChatCompletionSchema, OpenAIV1ChatCompletionSchema,
} from "./openai"; } from "./openai";
import { APIFormatTransformer } from "./index";
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent // https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z export const GoogleAIV1GenerateContentSchema = z
@ -14,7 +14,7 @@ export const GoogleAIV1GenerateContentSchema = z
z.object({ z.object({
parts: z.array(z.object({ text: z.string() })), parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]), role: z.enum(["user", "model"]),
}), })
), ),
tools: z.array(z.object({})).max(0).optional(), tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(), safetySettings: z.array(z.object({})).max(0).optional(),
@ -37,9 +37,9 @@ export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema typeof GoogleAIV1GenerateContentSchema
>["contents"][0]; >["contents"][0];
export function openAIToGoogleAI( export const transformOpenAIToGoogleAI: APIFormatTransformer<
req: Request, typeof GoogleAIV1GenerateContentSchema
): z.infer<typeof GoogleAIV1GenerateContentSchema> { > = async (req) => {
const { body } = req; const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({ const result = OpenAIV1ChatCompletionSchema.safeParse({
...body, ...body,
@ -48,7 +48,7 @@ export function openAIToGoogleAI(
if (!result.success) { if (!result.success) {
req.log.warn( req.log.warn(
{ issues: result.error.issues, body }, { issues: result.error.issues, body },
"Invalid OpenAI-to-Google AI request", "Invalid OpenAI-to-Google AI request"
); );
throw result.error; throw result.error;
} }
@ -121,4 +121,4 @@ export function openAIToGoogleAI(
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" }, { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
], ],
}; };
} };

View File

@ -1,18 +1,57 @@
import type { Request } from "express";
import { z } from "zod"; import { z } from "zod";
import { APIFormat } from "../key-management"; import { APIFormat } from "../key-management";
import { AnthropicV1TextSchema, AnthropicV1MessagesSchema } from "./anthropic"; import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
transformAnthropicTextToAnthropicChat,
transformOpenAIToAnthropicText,
transformOpenAIToAnthropicChat,
} from "./anthropic";
import { OpenAIV1ChatCompletionSchema } from "./openai"; import { OpenAIV1ChatCompletionSchema } from "./openai";
import { OpenAIV1TextCompletionSchema } from "./openai-text"; import {
import { OpenAIV1ImagesGenerationSchema } from "./openai-image"; OpenAIV1TextCompletionSchema,
import { GoogleAIV1GenerateContentSchema } from "./google-ai"; transformOpenAIToOpenAIText,
} from "./openai-text";
import {
OpenAIV1ImagesGenerationSchema,
transformOpenAIToOpenAIImage,
} from "./openai-image";
import {
GoogleAIV1GenerateContentSchema,
transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai"; import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
export { OpenAIChatMessage } from "./openai"; export { OpenAIChatMessage } from "./openai";
export { AnthropicChatMessage, flattenAnthropicMessages } from "./anthropic"; export {
AnthropicChatMessage,
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
flattenAnthropicMessages,
} from "./anthropic";
export { GoogleAIChatMessage } from "./google-ai"; export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai"; export { MistralAIChatMessage } from "./mistral-ai";
export const API_SCHEMA_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = { type APIPair = `${APIFormat}->${APIFormat}`;
// Partial map from "inbound->outbound" API-format pairs to the transformer
// that rewrites a request body between those formats.
type TransformerMap = {
  [key in APIPair]?: APIFormatTransformer<any>;
};

/**
 * A function that takes an Express request whose body is in one API format
 * and resolves to the equivalent body in another format, typed by the target
 * Zod schema `Z`.
 */
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
  req: Request
) => Promise<z.infer<Z>>;
/**
 * Registry of supported request translations. Pairs not listed here cannot
 * be proxied between formats and are rejected by the request preprocessor.
 */
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
  "anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
  "openai->anthropic-chat": transformOpenAIToAnthropicChat,
  "openai->anthropic-text": transformOpenAIToAnthropicText,
  "openai->openai-text": transformOpenAIToOpenAIText,
  "openai->openai-image": transformOpenAIToOpenAIImage,
  "openai->google-ai": transformOpenAIToGoogleAI,
};
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema, "anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema, "anthropic-text": AnthropicV1TextSchema,
openai: OpenAIV1ChatCompletionSchema, openai: OpenAIV1ChatCompletionSchema,

View File

@ -1,6 +1,6 @@
import { z } from "zod"; import { z } from "zod";
import { Request } from "express";
import { OpenAIV1ChatCompletionSchema } from "./openai"; import { OpenAIV1ChatCompletionSchema } from "./openai";
import { APIFormatTransformer } from "./index";
// https://platform.openai.com/docs/api-reference/images/create // https://platform.openai.com/docs/api-reference/images/create
export const OpenAIV1ImagesGenerationSchema = z export const OpenAIV1ImagesGenerationSchema = z
@ -20,47 +20,49 @@ export const OpenAIV1ImagesGenerationSchema = z
.strip(); .strip();
// Takes the last chat message and uses it verbatim as the image prompt. // Takes the last chat message and uses it verbatim as the image prompt.
export function openAIToOpenAIImage(req: Request) { export const transformOpenAIToOpenAIImage: APIFormatTransformer<
const { body } = req; typeof OpenAIV1ImagesGenerationSchema
const result = OpenAIV1ChatCompletionSchema.safeParse(body); > = async (req) => {
if (!result.success) { const { body } = req;
req.log.warn( const result = OpenAIV1ChatCompletionSchema.safeParse(body);
{ issues: result.error.issues, body }, if (!result.success) {
"Invalid OpenAI-to-OpenAI-image request", req.log.warn(
); { issues: result.error.issues, body },
throw result.error; "Invalid OpenAI-to-OpenAI-image request"
} );
throw result.error;
}
const { messages } = result.data; const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content; const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (Array.isArray(prompt)) { if (Array.isArray(prompt)) {
throw new Error("Image generation prompt must be a text message."); throw new Error("Image generation prompt must be a text message.");
} }
if (body.stream) { if (body.stream) {
throw new Error( throw new Error(
"Streaming is not supported for image generation requests.", "Streaming is not supported for image generation requests."
); );
} }
// Some frontends do weird things with the prompt, like prefixing it with a // Some frontends do weird things with the prompt, like prefixing it with a
// character name or wrapping the entire thing in quotes. We will look for // character name or wrapping the entire thing in quotes. We will look for
// the index of "Image:" and use everything after that as the prompt. // the index of "Image:" and use everything after that as the prompt.
const index = prompt?.toLowerCase().indexOf("image:"); const index = prompt?.toLowerCase().indexOf("image:");
if (index === -1 || !prompt) { if (index === -1 || !prompt) {
throw new Error( throw new Error(
`Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`, `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`
); );
} }
// TODO: Add some way to specify parameters via chat message // TODO: Add some way to specify parameters via chat message
const transformed = { const transformed = {
model: body.model.includes("dall-e") ? body.model : "dall-e-3", model: body.model.includes("dall-e") ? body.model : "dall-e-3",
quality: "standard", quality: "standard",
size: "1024x1024", size: "1024x1024",
response_format: "url", response_format: "url",
prompt: prompt.slice(index! + 6).trim(), prompt: prompt.slice(index! + 6).trim(),
}; };
return OpenAIV1ImagesGenerationSchema.parse(transformed); return OpenAIV1ImagesGenerationSchema.parse(transformed);
} };

View File

@ -3,7 +3,7 @@ import {
flattenOpenAIChatMessages, flattenOpenAIChatMessages,
OpenAIV1ChatCompletionSchema, OpenAIV1ChatCompletionSchema,
} from "./openai"; } from "./openai";
import { Request } from "express"; import { APIFormatTransformer } from "./index";
export const OpenAIV1TextCompletionSchema = z export const OpenAIV1TextCompletionSchema = z
.object({ .object({
@ -29,7 +29,9 @@ export const OpenAIV1TextCompletionSchema = z
.strip() .strip()
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true, logprobs: true })); .merge(OpenAIV1ChatCompletionSchema.omit({ messages: true, logprobs: true }));
export function openAIToOpenAIText(req: Request) { export const transformOpenAIToOpenAIText: APIFormatTransformer<
typeof OpenAIV1TextCompletionSchema
> = async (req) => {
const { body } = req; const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body); const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) { if (!result.success) {
@ -53,4 +55,4 @@ export function openAIToOpenAIText(req: Request) {
const transformed = { ...rest, prompt: prompt, stop: stops }; const transformed = { ...rest, prompt: prompt, stop: stops };
return OpenAIV1TextCompletionSchema.parse(transformed); return OpenAIV1TextCompletionSchema.parse(transformed);
} };

View File

@ -338,12 +338,13 @@ function refreshAllQuotas() {
// store to sync it with Firebase when it changes. Will refactor to abstract // store to sync it with Firebase when it changes. Will refactor to abstract
// persistence layer later so we can support multiple stores. // persistence layer later so we can support multiple stores.
let firebaseTimeout: NodeJS.Timeout | undefined; let firebaseTimeout: NodeJS.Timeout | undefined;
const USERS_REF = process.env.FIREBASE_USERS_REF_NAME ?? "users";
async function initFirebase() { async function initFirebase() {
log.info("Connecting to Firebase..."); log.info("Connecting to Firebase...");
const app = getFirebaseApp(); const app = getFirebaseApp();
const db = admin.database(app); const db = admin.database(app);
const usersRef = db.ref("users"); const usersRef = db.ref(USERS_REF);
const snapshot = await usersRef.once("value"); const snapshot = await usersRef.once("value");
const users: Record<string, User> | null = snapshot.val(); const users: Record<string, User> | null = snapshot.val();
firebaseTimeout = setInterval(flushUsers, 20 * 1000); firebaseTimeout = setInterval(flushUsers, 20 * 1000);
@ -362,7 +363,7 @@ async function initFirebase() {
async function flushUsers() { async function flushUsers() {
const app = getFirebaseApp(); const app = getFirebaseApp();
const db = admin.database(app); const db = admin.database(app);
const usersRef = db.ref("users"); const usersRef = db.ref(USERS_REF);
const updates: Record<string, User> = {}; const updates: Record<string, User> = {};
const deletions = []; const deletions = [];

View File

@ -1,7 +1,7 @@
{ {
"compilerOptions": { "compilerOptions": {
"strict": true, "strict": true,
"target": "ES2020", "target": "ES2022",
"module": "CommonJS", "module": "CommonJS",
"moduleResolution": "node", "moduleResolution": "node",
"esModuleInterop": true, "esModuleInterop": true,