This commit is contained in:
nai-degen 2024-02-04 13:31:27 -06:00
parent 235510e588
commit 6f7abf0220
3 changed files with 94 additions and 31 deletions

58
src/proxy/combined.ts Normal file
View File

@ -0,0 +1,58 @@
/* Provides a single endpoint for all services. */
import { RequestHandler } from "express";
import { generateErrorMessage } from "zod-error";
import { APIFormat } from "../shared/key-management";
import {
getServiceForModel,
LLMService,
MODEL_FAMILIES,
MODEL_FAMILY_SERVICE,
ModelFamily,
} from "../shared/models";
import { API_SCHEMA_VALIDATORS } from "../shared/api-schemas";
const detectApiFormat = (body: any, formats: APIFormat[]): APIFormat => {
const errors = [];
for (const format of formats) {
const result = API_SCHEMA_VALIDATORS[format].safeParse(body);
if (result.success) {
return format;
} else {
errors.push(result.error);
}
}
throw new Error(`Couldn't determine the format of your request. Errors: ${errors}`);
};
/**
* Tries to infer LLMService and APIFormat using the model name and the presence
* of certain fields in the request body.
*/
const inferService: RequestHandler = (req, res, next) => {
const model = req.body.model;
if (!model) {
throw new Error("No model specified");
}
// Service determines the key provider and is typically determined by the
// requested model, though some models are served by multiple services.
// API format determines the expected request/response format.
let service: LLMService;
let inboundApi: APIFormat;
let outboundApi: APIFormat;
if (MODEL_FAMILIES.includes(model)) {
service = MODEL_FAMILY_SERVICE[model as ModelFamily];
} else {
service = getServiceForModel(model);
}
// Each service has typically one API format.
switch (service) {
case "openai": {
const detected = detectApiFormat(req.body, ["openai", "openai-text", "openai-image"]);
}
}
};

View File

@ -4,8 +4,13 @@ import os from "os";
import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
import { Key, Model, KeyProvider } from "./index";
import {
getServiceForModel,
LLMService,
MODEL_FAMILY_SERVICE,
ModelFamily,
} from "../models";
import { Key, KeyProvider, Model } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
@ -42,7 +47,7 @@ export class KeyPool {
}
public get(model: Model): Key {
const service = this.getServiceForModel(model);
const service = getServiceForModel(model);
return this.getKeyProvider(service).get(model);
}
@ -72,7 +77,7 @@ export class KeyPool {
public available(model: Model | "all" = "all"): number {
return this.keyProviders.reduce((sum, provider) => {
const includeProvider =
model === "all" || this.getServiceForModel(model) === provider.service;
model === "all" || getServiceForModel(model) === provider.service;
return sum + (includeProvider ? provider.available() : 0);
}, 0);
}
@ -109,33 +114,6 @@ export class KeyPool {
provider.recheck();
}
private getServiceForModel(model: Model): LLMService {
if (
model.startsWith("gpt") ||
model.startsWith("text-embedding-ada") ||
model.startsWith("dall-e")
) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-ai";
} else if (model.includes("mistral")) {
// https://docs.mistral.ai/platform/endpoints
return "mistral-ai";
} else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
private getKeyProvider(service: LLMService): KeyProvider {
return this.keyProviders.find((provider) => provider.service === service)!;
}

View File

@ -205,6 +205,33 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
return (req.modelFamily = modelFamily);
}
export function getServiceForModel(model: string): LLMService {
if (
model.startsWith("gpt") ||
model.startsWith("text-embedding-ada") ||
model.startsWith("dall-e")
) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-ai";
} else if (model.includes("mistral")) {
// https://docs.mistral.ai/platform/endpoints
return "mistral-ai";
} else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
function assertNever(x: never): never {
throw new Error(`Called assertNever with argument ${x}.`);
}