diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts
index 1088fe7..12c30e6 100644
--- a/src/proxy/anthropic.ts
+++ b/src/proxy/anthropic.ts
@@ -83,17 +83,19 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
     body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
   }
 
-  if (req.inboundApi === "openai") {
-    req.log.info("Transforming Anthropic text to OpenAI format");
-    body = transformAnthropicTextResponseToOpenAI(body, req);
-  }
-
-  if (
-    req.inboundApi === "anthropic-text" &&
-    req.outboundApi === "anthropic-chat"
-  ) {
-    req.log.info("Transforming Anthropic text to Anthropic chat format");
-    body = transformAnthropicChatResponseToAnthropicText(body);
+  switch (`${req.inboundApi}<-${req.outboundApi}`) {
+    case "openai<-anthropic-text":
+      req.log.info("Transforming Anthropic Text back to OpenAI format");
+      body = transformAnthropicTextResponseToOpenAI(body, req);
+      break;
+    case "openai<-anthropic-chat":
+      req.log.info("Transforming Anthropic Chat back to OpenAI format");
+      body = transformAnthropicChatResponseToOpenAI(body);
+      break;
+    case "anthropic-text<-anthropic-chat":
+      req.log.info("Transforming Anthropic Chat back to Anthropic Text format");
+      body = transformAnthropicChatResponseToAnthropicText(body);
+      break;
   }
 
   if (req.tokenizerInfo) {
@@ -103,17 +105,23 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
   res.status(200).json(body);
 };
 
+function flattenChatResponse(
+  content: { type: string; text: string }[]
+): string {
+  return content
+    .map((part: { type: string; text: string }) =>
+      part.type === "text" ? part.text : ""
+    )
+    .join("\n");
+}
+
 export function transformAnthropicChatResponseToAnthropicText(
   anthropicBody: Record<string, any>
 ): Record<string, any> {
   return {
     type: "completion",
-    id: "trans-" + anthropicBody.id,
-    completion: anthropicBody.content
-      .map((part: { type: string; text: string }) =>
-        part.type === "text" ? part.text : ""
-      )
-      .join(""),
+    id: "ant-" + anthropicBody.id,
+    completion: flattenChatResponse(anthropicBody.content),
     stop_reason: anthropicBody.stop_reason,
     stop: anthropicBody.stop_sequence,
     model: anthropicBody.model,
@@ -155,6 +163,28 @@ function transformAnthropicTextResponseToOpenAI(
   };
 }
 
+function transformAnthropicChatResponseToOpenAI(
+  anthropicBody: Record<string, any>
+): Record<string, any> {
+  return {
+    id: "ant-" + anthropicBody.id,
+    object: "chat.completion",
+    created: Math.floor(Date.now() / 1000), // OpenAI timestamps are Unix seconds
+    model: anthropicBody.model,
+    usage: anthropicBody.usage,
+    choices: [
+      {
+        message: {
+          role: "assistant",
+          content: flattenChatResponse(anthropicBody.content),
+        },
+        finish_reason: anthropicBody.stop_reason,
+        index: 0,
+      },
+    ],
+  };
+}
+
 const anthropicProxy = createQueueMiddleware({
   proxyMiddleware: createProxyMiddleware({
     target: "https://api.anthropic.com",
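Reviewer note: to make the new `openai<-anthropic-chat` branch concrete, here is a hypothetical round trip through `transformAnthropicChatResponseToOpenAI` (all values invented):

```ts
// Hypothetical Anthropic Messages API response body:
const anthropicChatBody = {
  id: "msg_013Zva2CMHLNnXjNJJKqJ2EF",
  type: "message",
  role: "assistant",
  content: [{ type: "text", text: "Hello!" }],
  model: "claude-3-sonnet-20240229",
  stop_reason: "end_turn",
  stop_sequence: null,
  usage: { input_tokens: 10, output_tokens: 3 },
};

// transformAnthropicChatResponseToOpenAI(anthropicChatBody) returns roughly:
// {
//   id: "ant-msg_013Zva2CMHLNnXjNJJKqJ2EF",
//   object: "chat.completion",
//   created: 1709251200, // Unix seconds
//   model: "claude-3-sonnet-20240229",
//   usage: { input_tokens: 10, output_tokens: 3 },
//   choices: [
//     {
//       message: { role: "assistant", content: "Hello!" },
//       finish_reason: "end_turn",
//       index: 0,
//     },
//   ],
// }
```

Note that `usage` and `stop_reason` are passed through in Anthropic's vocabulary (`input_tokens`, `end_turn`) rather than mapped to OpenAI's (`prompt_tokens`, `stop`); strict OpenAI clients may care about that.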
@@ -178,6 +208,9 @@
     if (isText && pathname === "/v1/chat/completions") {
       req.url = "/v1/complete";
     }
+    if (isChat && pathname === "/v1/chat/completions") {
+      req.url = "/v1/messages";
+    }
     if (isChat && ["sonnet", "opus"].includes(req.params.type)) {
       req.url = "/v1/messages";
     }
@@ -202,7 +235,7 @@ const textToChatPreprocessor = createPreprocessorMiddleware({
  * Routes text completion prompts to anthropic-chat if they need translation
  * (claude-3 based models do not support the old text completion endpoint).
  */
-const claudeTextCompletionRouter: RequestHandler = (req, res, next) => {
+const preprocessAnthropicTextRequest: RequestHandler = (req, res, next) => {
   if (req.body.model?.startsWith("claude-3")) {
     textToChatPreprocessor(req, res, next);
   } else {
@@ -210,15 +243,33 @@ const claudeTextCompletionRouter: RequestHandler = (req, res, next) => {
   }
 };
 
+const oaiToTextPreprocessor = createPreprocessorMiddleware({
+  inApi: "openai",
+  outApi: "anthropic-text",
+  service: "anthropic",
+});
+
+const oaiToChatPreprocessor = createPreprocessorMiddleware({
+  inApi: "openai",
+  outApi: "anthropic-chat",
+  service: "anthropic",
+});
+
+/**
+ * Routes an OpenAI prompt to either the legacy Claude text completion endpoint
+ * or the new Claude chat completion endpoint, based on the requested model.
+ */
+const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
+  maybeReassignModel(req);
+  if (req.body.model?.includes("claude-3")) {
+    oaiToChatPreprocessor(req, res, next);
+  } else {
+    oaiToTextPreprocessor(req, res, next);
+  }
+};
+
 const anthropicRouter = Router();
 anthropicRouter.get("/v1/models", handleModelRequest);
-// Anthropic text completion endpoint. Dynamic routing based on model.
-anthropicRouter.post(
-  "/v1/complete",
-  ipLimiter,
-  claudeTextCompletionRouter,
-  anthropicProxy
-);
 // Native Anthropic chat completion endpoint.
 anthropicRouter.post(
   "/v1/messages",
@@ -230,23 +281,30 @@ anthropicRouter.post(
   }),
   anthropicProxy
 );
-// OpenAI-to-Anthropic Text compatibility endpoint.
+// Anthropic text completion endpoint. Translates to Anthropic chat completion
+// if the requested model is a Claude 3 model.
+anthropicRouter.post(
+  "/v1/complete",
+  ipLimiter,
+  preprocessAnthropicTextRequest,
+  anthropicProxy
+);
+// OpenAI-to-Anthropic compatibility endpoint. Accepts an OpenAI chat completion
+// request and transforms/routes it to the appropriate Anthropic format and
+// endpoint based on the requested model.
 anthropicRouter.post(
   "/v1/chat/completions",
   ipLimiter,
-  createPreprocessorMiddleware(
-    { inApi: "openai", outApi: "anthropic-text", service: "anthropic" },
-    { afterTransform: [maybeReassignModel] }
-  ),
+  preprocessOpenAICompatRequest,
   anthropicProxy
 );
-// Temporary force Anthropic Text to Anthropic Chat for frontends which do not
+// Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
 // yet support the new model. Forces claude-3. Will be removed once common
 // frontends have been updated.
 anthropicRouter.post(
   "/v1/:type(sonnet|opus)/:action(complete|messages)",
   ipLimiter,
-  handleCompatibilityRequest,
+  handleAnthropicTextCompatRequest,
  createPreprocessorMiddleware({
     inApi: "anthropic-text",
     outApi: "anthropic-chat",
@@ -255,7 +313,11 @@ anthropicRouter.post(
   anthropicProxy
 );
 
-function handleCompatibilityRequest(req: Request, res: Response, next: any) {
+function handleAnthropicTextCompatRequest(
+  req: Request,
+  res: Response,
+  next: any
+) {
   const type = req.params.type;
   const action = req.params.action;
   const alreadyInChatFormat = Boolean(req.body.messages);
@@ -287,10 +349,14 @@ function handleCompatibilityRequest(req: Request, res: Response, next: any) {
   next();
 }
 
+/**
+ * If a client using the OpenAI compatibility endpoint requests an actual
+ * OpenAI model, reassigns it to Claude 3 Sonnet.
+ */
 function maybeReassignModel(req: Request) {
   const model = req.body.model;
   if (!model.startsWith("gpt-")) return;
-  req.body.model = "claude-2.1";
+  req.body.model = "claude-3-sonnet-20240229";
 }
 
 export const anthropic = anthropicRouter;
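Reviewer note: a sketch of how an OpenAI-style client now flows through the compatibility endpoint. The proxy URL and request values are hypothetical; the model-rewriting behavior comes from `maybeReassignModel` and `preprocessOpenAICompatRequest` above.

```ts
// An OpenAI-style client pointed at the proxy. If it asks for a gpt-* model,
// maybeReassignModel silently swaps in Claude 3 Sonnet, and the request is
// translated and served by Anthropic's /v1/messages endpoint.
const PROXY_URL = "http://localhost:7860/proxy/anthropic"; // hypothetical mount

const res = await fetch(`${PROXY_URL}/v1/chat/completions`, {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    model: "gpt-4", // reassigned to "claude-3-sonnet-20240229" by the proxy
    max_tokens: 256,
    messages: [{ role: "user", content: "Hello!" }],
  }),
});
console.log(await res.json()); // OpenAI-shaped chat.completion response
```

A non-Claude-3 model (e.g. `claude-2.1`) instead takes the `oaiToTextPreprocessor` path and is proxied to `/v1/complete`.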
diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts
index 766f7d2..f15b32e 100644
--- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts
+++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts
@@ -5,7 +5,7 @@ import { HttpRequest } from "@smithy/protocol-http";
 import {
   AnthropicV1TextSchema,
   AnthropicV1MessagesSchema,
-} from "../../../../shared/api-schemas/anthropic";
+} from "../../../../shared/api-schemas";
 import { keyPool } from "../../../../shared/key-management";
 import { RequestPreprocessor } from "../index";
 
diff --git a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
index 755da26..1186367 100644
--- a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
+++ b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
@@ -1,12 +1,9 @@
 import {
-  anthropicTextToAnthropicChat,
-  openAIToAnthropicText,
-} from "../../../../shared/api-schemas/anthropic";
-import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text";
-import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image";
-import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai";
+  API_REQUEST_VALIDATORS,
+  API_REQUEST_TRANSFORMERS,
+} from "../../../../shared/api-schemas";
+import { BadRequestError } from "../../../../shared/errors";
 import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai";
-import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas";
 import {
   isImageGenerationRequest,
   isTextGenerationRequest,
@@ -22,6 +19,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
 
   if (alreadyTransformed || notTransformable) return;
 
+  // TODO: this should be an APIFormatTransformer
   if (req.inboundApi === "mistral-ai") {
     const messages = req.body.messages;
     req.body.messages = fixMistralPrompt(messages);
@@ -32,9 +30,9 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
   }
 
   if (sameService) {
-    const result = API_SCHEMA_VALIDATORS[req.inboundApi].safeParse(req.body);
+    const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
     if (!result.success) {
-      req.log.error(
+      req.log.warn(
         { issues: result.error.issues, body: req.body },
         "Request validation failed"
       );
@@ -44,35 +42,16 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
     return;
   }
 
-  if (
-    req.inboundApi === "anthropic-text" &&
-    req.outboundApi === "anthropic-chat"
-  ) {
-    req.body = anthropicTextToAnthropicChat(req);
+  const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
+  const transFn = API_REQUEST_TRANSFORMERS[transformation];
+
+  if (transFn) {
+    req.log.info({ transformation }, "Transforming request");
+    req.body = await transFn(req);
     return;
   }
 
-  if (req.inboundApi === "openai" && req.outboundApi === "anthropic-text") {
-    req.body = openAIToAnthropicText(req);
-    return;
-  }
-
-  if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
-    req.body = openAIToGoogleAI(req);
-    return;
-  }
-
-  if (req.inboundApi === "openai" && req.outboundApi === "openai-text") {
-    req.body = openAIToOpenAIText(req);
-    return;
-  }
-
-  if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
-    req.body = openAIToOpenAIImage(req);
-    return;
-  }
-
-  throw new Error(
-    `'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
+  throw new BadRequestError(
+    `${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
   );
 };
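Reviewer note: the chain of if-blocks collapses into a single keyed lookup. A minimal, self-contained sketch of the dispatch pattern (the formats and transformer here are stand-ins; the real map lives in `shared/api-schemas`):

```ts
// Stand-alone sketch of the "in->out" keyed dispatch used above.
type Format = "openai" | "anthropic-text" | "anthropic-chat";
type Pair = `${Format}->${Format}`;

const transformers: Partial<Record<Pair, (body: unknown) => Promise<unknown>>> = {
  // One entry per supported translation; unsupported pairs are simply absent.
  "openai->anthropic-chat": async (body) => ({ transformed: body }),
};

async function dispatch(inApi: Format, outApi: Format, body: unknown) {
  const key = `${inApi}->${outApi}` as const;
  const transFn = transformers[key];
  if (!transFn) throw new Error(`${key} proxying is not supported.`);
  return transFn(body);
}
```

Unsupported pairs now fail as a 400 (`BadRequestError`) instead of an unhandled `Error`, which is the right signal for a client misconfiguration.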
diff --git a/src/proxy/middleware/response/streaming/index.ts b/src/proxy/middleware/response/streaming/index.ts
index 540b166..402c233 100644
--- a/src/proxy/middleware/response/streaming/index.ts
+++ b/src/proxy/middleware/response/streaming/index.ts
@@ -39,6 +39,7 @@ export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
 export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
 export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
 export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
+export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
 export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
 export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
 export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
index c0ffdce..90c0313 100644
--- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts
+++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
@@ -3,6 +3,7 @@ import { logger } from "../../../../logger";
 import { APIFormat } from "../../../../shared/key-management";
 import { assertNever } from "../../../../shared/utils";
 import {
+  anthropicChatToOpenAI,
   anthropicChatToAnthropicV2,
   anthropicV1ToOpenAI,
   AnthropicV2StreamEvent,
@@ -117,7 +118,11 @@ function eventIsOpenAIEvent(
 
 function getTransformer(
   responseApi: APIFormat,
-  version?: string
+  version?: string,
+  // There's only one case where we're not transforming back to OpenAI, which is
+  // an Anthropic Chat response for an Anthropic Text request. This parameter is
+  // only used for that case.
+  requestApi: APIFormat = "openai"
 ): StreamingCompletionTransformer<
   OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
 > {
@@ -132,7 +137,9 @@ function getTransformer(
         ? anthropicV1ToOpenAI
         : anthropicV2ToOpenAI;
     case "anthropic-chat":
-      return anthropicChatToAnthropicV2;
+      return requestApi === "anthropic-text"
+        ? anthropicChatToAnthropicV2
+        : anthropicChatToOpenAI;
     case "google-ai":
      return googleAIToOpenAI;
     case "openai-image":
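Reviewer note: a hedged summary of how the `anthropic-chat` case now resolves (`getTransformer` is module-private, so these calls are illustrative rather than runnable from outside this file):

```ts
// Illustrative selection for streamed anthropic-chat responses:
//
// getTransformer("anthropic-chat", undefined, "anthropic-text")
//   -> anthropicChatToAnthropicV2   // text-completion clients keep the
//                                   // Anthropic V2 event shape
// getTransformer("anthropic-chat")
//   -> anthropicChatToOpenAI        // every other client gets
//                                   // OpenAI-shaped SSE events
```

The practical consequence is that whatever constructs the transformer now has to thread the original request's inbound API through to `getTransformer`, not just the response format.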
diff --git a/src/shared/api-schemas/anthropic.ts b/src/shared/api-schemas/anthropic.ts
index dbe3534..dd56fc1 100644
--- a/src/shared/api-schemas/anthropic.ts
+++ b/src/shared/api-schemas/anthropic.ts
@@ -1,11 +1,11 @@
 import { z } from "zod";
-import { Request } from "express";
 import { config } from "../../config";
 import {
   flattenOpenAIMessageContent,
   OpenAIChatMessage,
   OpenAIV1ChatCompletionSchema,
 } from "./openai";
+import { APIFormatTransformer } from "./index";
 
 const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
 
@@ -69,9 +69,7 @@ export type AnthropicChatMessage = z.infer<
   typeof AnthropicV1MessagesSchema
 >["messages"][0];
 
-export function openAIMessagesToClaudeTextPrompt(
-  messages: OpenAIChatMessage[]
-) {
+function openAIMessagesToClaudeTextPrompt(messages: OpenAIChatMessage[]) {
   return (
     messages
       .map((m) => {
@@ -93,7 +91,44 @@ export function openAIMessagesToClaudeTextPrompt(
   );
 }
 
-export function openAIToAnthropicText(req: Request) {
+export const transformOpenAIToAnthropicChat: APIFormatTransformer<
+  typeof AnthropicV1MessagesSchema
+> = async (req) => {
+  const { body } = req;
+  const result = OpenAIV1ChatCompletionSchema.safeParse(body);
+  if (!result.success) {
+    req.log.warn(
+      { issues: result.error.issues, body },
+      "Invalid OpenAI-to-Anthropic Chat request"
+    );
+    throw result.error;
+  }
+
+  req.headers["anthropic-version"] = "2023-06-01";
+
+  const { messages, ...rest } = result.data;
+  const { messages: newMessages, system } =
+    openAIMessagesToClaudeChatPrompt(messages);
+
+  return {
+    system,
+    messages: newMessages,
+    model: rest.model,
+    max_tokens: rest.max_tokens,
+    stream: rest.stream,
+    temperature: rest.temperature,
+    top_p: rest.top_p,
+    stop_sequences: typeof rest.stop === "string" ? [rest.stop] : rest.stop,
+    ...(rest.user ? { metadata: { user_id: rest.user } } : {}),
+    // Anthropic supports top_k, but OpenAI does not.
+    // OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
+    // and function calls, but Anthropic does not.
+  };
+};
+
+export const transformOpenAIToAnthropicText: APIFormatTransformer<
+  typeof AnthropicV1TextSchema
+> = async (req) => {
   const { body } = req;
   const result = OpenAIV1ChatCompletionSchema.safeParse(body);
   if (!result.success) {
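Reviewer note: a hypothetical before/after for `transformOpenAIToAnthropicChat` (all values invented):

```ts
// OpenAI chat completion request body:
const openAIRequestBody = {
  model: "claude-3-sonnet-20240229",
  max_tokens: 256,
  stop: "\n\nUser:", // OpenAI allows a bare string...
  user: "user-1234",
  messages: [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "Hi!" },
  ],
};

// ...is reshaped to roughly:
const anthropicRequestBody = {
  system: "You are a helpful assistant.",
  messages: [{ role: "user", content: [{ type: "text", text: "Hi!" }] }],
  model: "claude-3-sonnet-20240229",
  max_tokens: 256,
  stop_sequences: ["\n\nUser:"], // ...which Anthropic wants as an array
  metadata: { user_id: "user-1234" }, // only present when `user` was sent
};
```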
"\n\nAssistant:" : "\n\nHuman:"; const nextIndex = remaining.indexOf(nextRole); @@ -199,7 +236,7 @@ export function anthropicTextToAnthropicChat(req: Request) { max_tokens: max_tokens_to_sample, ...rest, }; -} +}; function validateAnthropicTextPrompt(prompt: string) { if (!prompt.includes("\n\nHuman:") || !prompt.includes("\n\nAssistant:")) { @@ -236,3 +273,167 @@ export function flattenAnthropicMessages( }) .join("\n\n"); } + +/** + * Represents the union of all content types without the `string` shorthand + * for `text` content. + */ +type AnthropicChatMessageContentWithoutString = Exclude< + AnthropicChatMessage["content"], + string +>; +/** Represents a message with all shorthand `string` content expanded. */ +type ConvertedAnthropicChatMessage = AnthropicChatMessage & { + content: AnthropicChatMessageContentWithoutString; +}; + +function openAIMessagesToClaudeChatPrompt(messages: OpenAIChatMessage[]): { + messages: AnthropicChatMessage[]; + system: string; +} { + // Similar formats, but Claude doesn't use `name` property and doesn't have + // a `system` role. Also, Claude does not allow consecutive messages from + // the same role, so we need to merge them. + // 1. Collect all system messages up to the first non-system message and set + // that as the `system` prompt. + // 2. Iterate through messages and: + // - If the message is from system, reassign it to assistant with System: + // prefix. + // - If message is from same role as previous, append it to the previous + // message rather than creating a new one. + // - Otherwise, create a new message and prefix with `name` if present. + + // TODO: When a Claude message has multiple `text` contents, does the internal + // message flattening insert newlines between them? If not, we may need to + // do that here... + + let firstNonSystem = -1; + const result: { messages: ConvertedAnthropicChatMessage[]; system: string } = + { messages: [], system: "" }; + for (let i = 0; i < messages.length; i++) { + const msg = messages[i]; + const isSystem = isSystemOpenAIRole(msg.role); + + if (firstNonSystem === -1 && isSystem) { + // Still merging initial system messages into the system prompt + result.system += getFirstTextContent(msg.content) + "\n"; + continue; + } + + if (firstNonSystem === -1 && !isSystem) { + // Encountered the first non-system message + firstNonSystem = i; + + if (msg.role === "assistant") { + // There is an annoying rule that the first message must be from the user. + // This is commonly not the case with roleplay prompts that start with a + // block of system messages followed by an assistant message. We will try + // to reconcile this by splicing the last line of the system prompt into + // a beginning user message -- this is *commonly* ST's [Start a new chat] + // nudge, which works okay as a user message. + + // Find the last non-empty line in the system prompt + const execResult = /(?:[^\r\n]*\r?\n)*([^\r\n]+)(?:\r?\n)*/d.exec( + result.system + ); + + let text = ""; + if (execResult) { + text = execResult[1]; + // Remove last line from system so it doesn't get duplicated + const [_, [lastLineStart]] = execResult.indices || []; + result.system = result.system.slice(0, lastLineStart); + } else { + // This is a bad prompt; there's no system content to move to user and + // it starts with assistant. We don't have any good options. + text = "[ Joining chat... 
]"; + } + + result.messages.push({ + role: "user", + content: [{ type: "text", text }], + }); + } + } + + const last = result.messages[result.messages.length - 1]; + // I have to handle tools as system messages to be exhaustive here but the + // experience will be bad. + const role = isSystemOpenAIRole(msg.role) ? "assistant" : msg.role; + + // Here we will lose the original name if it was a system message, but that + // is generally okay because the system message is usually a prompt and not + // a character in the chat. + const name = msg.role === "system" ? "System" : msg.name?.trim(); + const content = convertOpenAIContent(msg.content); + + // Prepend the display name to the first text content in the current message + // if it exists. We don't need to add the name to every content block. + if (name?.length) { + const firstTextContent = content.find((c) => c.type === "text"); + if (firstTextContent && "text" in firstTextContent) { + // This mutates the element in `content`. + firstTextContent.text = `${name}: ${firstTextContent.text}`; + } + } + + // Merge messages if necessary. If two assistant roles are consecutive but + // had different names, the final converted assistant message will have + // multiple characters in it, but the name prefixes should assist the model + // in differentiating between speakers. + if (last && last.role === role) { + last.content.push(...content); + } else { + result.messages.push({ role, content }); + } + } + + result.system = result.system.trimEnd(); + return result; +} + +function isSystemOpenAIRole( + role: OpenAIChatMessage["role"] +): role is "system" | "function" | "tool" { + return ["system", "function", "tool"].includes(role); +} + +function getFirstTextContent(content: OpenAIChatMessage["content"]) { + if (typeof content === "string") return content; + for (const c of content) { + if ("text" in c) return c.text; + } + return "[ No text content in this message ]"; +} + +function convertOpenAIContent( + content: OpenAIChatMessage["content"] +): AnthropicChatMessageContentWithoutString { + if (typeof content === "string") { + return [{ type: "text", text: content.trimEnd() }]; + } + + return content.map((c) => { + if ("text" in c) { + return { type: "text", text: c.text.trimEnd() }; + } else if ("image_url" in c) { + const url = c.image_url.url; + try { + const mimeType = url.split(";")[0].split(":")[1]; + const data = url.split(",")[1]; + return { + type: "image", + source: { type: "base64", media_type: mimeType, data }, + }; + } catch (e) { + return { + type: "text", + text: `[ Unsupported image URL: ${url.slice(0, 200)} ]`, + }; + } + } else { + const type = String((c as any)?.type); + return { type: "text", text: `[ Unsupported content type: ${type} ]` }; + } + }); +} diff --git a/src/shared/api-schemas/google-ai.ts b/src/shared/api-schemas/google-ai.ts index bd525b7..62239ad 100644 --- a/src/shared/api-schemas/google-ai.ts +++ b/src/shared/api-schemas/google-ai.ts @@ -1,9 +1,9 @@ import { z } from "zod"; -import { Request } from "express"; import { flattenOpenAIMessageContent, OpenAIV1ChatCompletionSchema, } from "./openai"; +import { APIFormatTransformer } from "./index"; // https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent export const GoogleAIV1GenerateContentSchema = z @@ -14,7 +14,7 @@ export const GoogleAIV1GenerateContentSchema = z z.object({ parts: z.array(z.object({ text: z.string() })), role: z.enum(["user", "model"]), - }), + }) ), tools: z.array(z.object({})).max(0).optional(), 
    safetySettings: z.array(z.object({})).max(0).optional(),
@@ -37,9 +37,9 @@ export type GoogleAIChatMessage = z.infer<
   typeof GoogleAIV1GenerateContentSchema
 >["contents"][0];
 
-export function openAIToGoogleAI(
-  req: Request,
-): z.infer<typeof GoogleAIV1GenerateContentSchema> {
+export const transformOpenAIToGoogleAI: APIFormatTransformer<
+  typeof GoogleAIV1GenerateContentSchema
+> = async (req) => {
   const { body } = req;
   const result = OpenAIV1ChatCompletionSchema.safeParse({
     ...body,
@@ -48,7 +48,7 @@ export function openAIToGoogleAI(
   if (!result.success) {
     req.log.warn(
       { issues: result.error.issues, body },
-      "Invalid OpenAI-to-Google AI request",
+      "Invalid OpenAI-to-Google AI request"
     );
     throw result.error;
   }
@@ -121,4 +121,4 @@ export function openAIToGoogleAI(
       { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
     ],
   };
-}
+};
diff --git a/src/shared/api-schemas/index.ts b/src/shared/api-schemas/index.ts
index 139dbc7..598bf23 100644
--- a/src/shared/api-schemas/index.ts
+++ b/src/shared/api-schemas/index.ts
@@ -1,18 +1,57 @@
+import type { Request } from "express";
 import { z } from "zod";
 import { APIFormat } from "../key-management";
-import { AnthropicV1TextSchema, AnthropicV1MessagesSchema } from "./anthropic";
+import {
+  AnthropicV1TextSchema,
+  AnthropicV1MessagesSchema,
+  transformAnthropicTextToAnthropicChat,
+  transformOpenAIToAnthropicText,
+  transformOpenAIToAnthropicChat,
+} from "./anthropic";
 import { OpenAIV1ChatCompletionSchema } from "./openai";
-import { OpenAIV1TextCompletionSchema } from "./openai-text";
-import { OpenAIV1ImagesGenerationSchema } from "./openai-image";
-import { GoogleAIV1GenerateContentSchema } from "./google-ai";
+import {
+  OpenAIV1TextCompletionSchema,
+  transformOpenAIToOpenAIText,
+} from "./openai-text";
+import {
+  OpenAIV1ImagesGenerationSchema,
+  transformOpenAIToOpenAIImage,
+} from "./openai-image";
+import {
+  GoogleAIV1GenerateContentSchema,
+  transformOpenAIToGoogleAI,
+} from "./google-ai";
 import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
 
 export { OpenAIChatMessage } from "./openai";
-export { AnthropicChatMessage, flattenAnthropicMessages } from "./anthropic";
+export {
+  AnthropicChatMessage,
+  AnthropicV1TextSchema,
+  AnthropicV1MessagesSchema,
+  flattenAnthropicMessages,
+} from "./anthropic";
 export { GoogleAIChatMessage } from "./google-ai";
 export { MistralAIChatMessage } from "./mistral-ai";
 
-export const API_SCHEMA_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
+type APIPair = `${APIFormat}->${APIFormat}`;
+type TransformerMap = {
+  [key in APIPair]?: APIFormatTransformer<any>;
+};
+
+export type APIFormatTransformer<Z extends z.ZodType<any>> = (
+  req: Request
+) => Promise<z.infer<Z>>;
+
+export const API_REQUEST_TRANSFORMERS: TransformerMap = {
+  "anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
+  "openai->anthropic-chat": transformOpenAIToAnthropicChat,
+  "openai->anthropic-text": transformOpenAIToAnthropicText,
+  "openai->openai-text": transformOpenAIToOpenAIText,
+  "openai->openai-image": transformOpenAIToOpenAIImage,
+  "openai->google-ai": transformOpenAIToGoogleAI,
+};
+
+export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
   "anthropic-chat": AnthropicV1MessagesSchema,
   "anthropic-text": AnthropicV1TextSchema,
   openai: OpenAIV1ChatCompletionSchema,
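Reviewer note: supporting a new format pair is now one transformer plus one map entry. A hypothetical sketch — the `mistral-ai->anthropic-chat` pair and its conversion body are invented; only the registration pattern comes from this PR (and `TransformerMap` would need exporting for use outside this module):

```ts
const transformMistralToAnthropicChat: APIFormatTransformer<
  typeof AnthropicV1MessagesSchema
> = async (req) => {
  const result = MistralAIV1ChatCompletionsSchema.safeParse(req.body);
  if (!result.success) throw result.error;
  // ...convert result.data.messages into Anthropic chat messages here...
  return AnthropicV1MessagesSchema.parse(req.body); // placeholder conversion
};

// Then a single new entry in API_REQUEST_TRANSFORMERS:
//   "mistral-ai->anthropic-chat": transformMistralToAnthropicChat,
```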
OpenAIV1ChatCompletionSchema } from "./openai"; +import { APIFormatTransformer } from "./index"; // https://platform.openai.com/docs/api-reference/images/create export const OpenAIV1ImagesGenerationSchema = z @@ -20,47 +20,49 @@ export const OpenAIV1ImagesGenerationSchema = z .strip(); // Takes the last chat message and uses it verbatim as the image prompt. -export function openAIToOpenAIImage(req: Request) { - const { body } = req; - const result = OpenAIV1ChatCompletionSchema.safeParse(body); - if (!result.success) { - req.log.warn( - { issues: result.error.issues, body }, - "Invalid OpenAI-to-OpenAI-image request", - ); - throw result.error; - } +export const transformOpenAIToOpenAIImage: APIFormatTransformer< + typeof OpenAIV1ImagesGenerationSchema +> = async (req) => { + const { body } = req; + const result = OpenAIV1ChatCompletionSchema.safeParse(body); + if (!result.success) { + req.log.warn( + { issues: result.error.issues, body }, + "Invalid OpenAI-to-OpenAI-image request" + ); + throw result.error; + } - const { messages } = result.data; - const prompt = messages.filter((m) => m.role === "user").pop()?.content; - if (Array.isArray(prompt)) { - throw new Error("Image generation prompt must be a text message."); - } + const { messages } = result.data; + const prompt = messages.filter((m) => m.role === "user").pop()?.content; + if (Array.isArray(prompt)) { + throw new Error("Image generation prompt must be a text message."); + } - if (body.stream) { - throw new Error( - "Streaming is not supported for image generation requests.", - ); - } + if (body.stream) { + throw new Error( + "Streaming is not supported for image generation requests." + ); + } - // Some frontends do weird things with the prompt, like prefixing it with a - // character name or wrapping the entire thing in quotes. We will look for - // the index of "Image:" and use everything after that as the prompt. + // Some frontends do weird things with the prompt, like prefixing it with a + // character name or wrapping the entire thing in quotes. We will look for + // the index of "Image:" and use everything after that as the prompt. - const index = prompt?.toLowerCase().indexOf("image:"); - if (index === -1 || !prompt) { - throw new Error( - `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`, - ); - } + const index = prompt?.toLowerCase().indexOf("image:"); + if (index === -1 || !prompt) { + throw new Error( + `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).` + ); + } - // TODO: Add some way to specify parameters via chat message - const transformed = { - model: body.model.includes("dall-e") ? body.model : "dall-e-3", - quality: "standard", - size: "1024x1024", - response_format: "url", - prompt: prompt.slice(index! + 6).trim(), - }; - return OpenAIV1ImagesGenerationSchema.parse(transformed); -} + // TODO: Add some way to specify parameters via chat message + const transformed = { + model: body.model.includes("dall-e") ? body.model : "dall-e-3", + quality: "standard", + size: "1024x1024", + response_format: "url", + prompt: prompt.slice(index! 
diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts
index 9a8c5f2..080d1f7 100644
--- a/src/shared/users/user-store.ts
+++ b/src/shared/users/user-store.ts
@@ -338,12 +338,13 @@ function refreshAllQuotas() {
 // store to sync it with Firebase when it changes. Will refactor to abstract
 // persistence layer later so we can support multiple stores.
 let firebaseTimeout: NodeJS.Timeout | undefined;
+const USERS_REF = process.env.FIREBASE_USERS_REF_NAME ?? "users";
 
 async function initFirebase() {
   log.info("Connecting to Firebase...");
   const app = getFirebaseApp();
   const db = admin.database(app);
-  const usersRef = db.ref("users");
+  const usersRef = db.ref(USERS_REF);
   const snapshot = await usersRef.once("value");
   const users: Record<string, User> | null = snapshot.val();
   firebaseTimeout = setInterval(flushUsers, 20 * 1000);
@@ -362,7 +363,7 @@ async function initFirebase() {
 async function flushUsers() {
   const app = getFirebaseApp();
   const db = admin.database(app);
-  const usersRef = db.ref("users");
+  const usersRef = db.ref(USERS_REF);
 
   const updates: Record<string, any> = {};
   const deletions = [];
diff --git a/tsconfig.json b/tsconfig.json
index a1762f4..3db51a1 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -1,7 +1,7 @@
 {
   "compilerOptions": {
     "strict": true,
-    "target": "ES2020",
+    "target": "ES2022",
     "module": "CommonJS",
     "moduleResolution": "node",
     "esModuleInterop": true,
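Reviewer note, two deployment-facing details. First, the `target` bump to ES2022 is what lets the RegExp match-indices (`d`) flag used in `openAIMessagesToClaudeChatPrompt` compile; a minimal illustration:

```ts
// ES2022 exposes match indices via the "d" flag:
const m = /(\w+)$/d.exec("hello world");
if (m?.indices) console.log(m.indices[1]); // [6, 11] — start/end of "world"
```

Second, the Firebase users ref is now configurable: setting `FIREBASE_USERS_REF_NAME` (e.g. `FIREBASE_USERS_REF_NAME=staging-users` in the environment) points both `initFirebase` and `flushUsers` at an alternate ref, with `users` remaining the default.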