wip
This commit is contained in:
parent
235510e588
commit
6f7abf0220
|
@ -0,0 +1,58 @@
|
|||
/* Provides a single endpoint for all services. */
|
||||
import { RequestHandler } from "express";
|
||||
import { generateErrorMessage } from "zod-error";
|
||||
import { APIFormat } from "../shared/key-management";
|
||||
import {
|
||||
getServiceForModel,
|
||||
LLMService,
|
||||
MODEL_FAMILIES,
|
||||
MODEL_FAMILY_SERVICE,
|
||||
ModelFamily,
|
||||
} from "../shared/models";
|
||||
import { API_SCHEMA_VALIDATORS } from "../shared/api-schemas";
|
||||
|
||||
/**
 * Determines which of the candidate API formats a request body conforms to.
 *
 * Each candidate's schema from `API_SCHEMA_VALIDATORS` is tried in the order
 * given; the first one that parses the body successfully wins.
 *
 * @param body - The raw request body to validate.
 * @param formats - Candidate formats, in order of preference.
 * @returns The first format whose schema accepts `body`.
 * @throws Error if no candidate schema accepts the body (including when
 *   `formats` is empty). The message lists the validation failures per format.
 */
const detectApiFormat = (body: any, formats: APIFormat[]): APIFormat => {
  // Human-readable failure summaries, one per rejected format.
  const failures: string[] = [];

  for (const format of formats) {
    const result = API_SCHEMA_VALIDATORS[format].safeParse(body);
    if (result.success) {
      return format;
    }
    // Interpolating ZodError objects directly yields an unreadable dump, so
    // render the issues with zod-error and label them with the format name.
    failures.push(`${format}: ${generateErrorMessage(result.error.issues)}`);
  }

  throw new Error(
    `Couldn't determine the format of your request. Errors: ${failures.join("; ")}`
  );
};
|
||||
|
||||
/**
 * Tries to infer LLMService and APIFormat using the model name and the presence
 * of certain fields in the request body.
 *
 * NOTE(review): this handler is still a work in progress. It resolves the
 * service but does not yet assign the inbound/outbound API formats, store any
 * result, or call `next()`.
 *
 * @throws Error if the request body has no `model`, or if `model` is not a
 *   string.
 */
const inferService: RequestHandler = (req, res, next) => {
  // `req.body` may be absent (e.g. no body parser ran), so don't dereference
  // it unconditionally.
  const model = req.body?.model;
  if (!model) {
    throw new Error("No model specified");
  }
  // The lookups below call string methods on `model`; reject anything else up
  // front instead of failing later with an opaque TypeError.
  if (typeof model !== "string") {
    throw new Error("Model must be a string");
  }

  // Service determines the key provider and is typically determined by the
  // requested model, though some models are served by multiple services.
  // API format determines the expected request/response format.
  let service: LLMService;
  let inboundApi: APIFormat;
  let outboundApi: APIFormat;

  // A model family name may be passed in place of a concrete model name.
  if (MODEL_FAMILIES.includes(model as ModelFamily)) {
    service = MODEL_FAMILY_SERVICE[model as ModelFamily];
  } else {
    service = getServiceForModel(model);
  }

  // Each service has typically one API format.
  switch (service) {
    case "openai": {
      const detected = detectApiFormat(req.body, ["openai", "openai-text", "openai-image"]);
      // TODO: use `detected` to set inboundApi/outboundApi.
    }
    // TODO: handle the remaining services.
  }
};
|
|
@ -4,8 +4,13 @@ import os from "os";
|
|||
import schedule from "node-schedule";
|
||||
import { config } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
|
||||
import { Key, Model, KeyProvider } from "./index";
|
||||
import {
|
||||
getServiceForModel,
|
||||
LLMService,
|
||||
MODEL_FAMILY_SERVICE,
|
||||
ModelFamily,
|
||||
} from "../models";
|
||||
import { Key, KeyProvider, Model } from "./index";
|
||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||
|
@ -42,7 +47,7 @@ export class KeyPool {
|
|||
}
|
||||
|
||||
/**
 * Returns a key that can serve the given model.
 *
 * The model is mapped to a service with the shared `getServiceForModel`
 * helper, and the key is requested from that service's provider.
 *
 * @param model - The model a key is needed for.
 * @returns The key chosen by the matching provider.
 */
public get(model: Model): Key {
  // Only one `service` binding: the stale `this.getServiceForModel` lookup
  // duplicated this declaration.
  const service = getServiceForModel(model);
  return this.getKeyProvider(service).get(model);
}
|
||||
|
||||
|
@ -72,7 +77,7 @@ export class KeyPool {
|
|||
/**
 * Counts the keys available across providers.
 *
 * @param model - A model to restrict the count to the provider whose service
 *   matches that model, or `"all"` (the default) to include every provider.
 * @returns The sum of `provider.available()` over the included providers.
 */
public available(model: Model | "all" = "all"): number {
  return this.keyProviders.reduce((sum, provider) => {
    // A single condition: the stale `this.getServiceForModel` variant left
    // the replacement as a dead expression statement.
    const includeProvider =
      model === "all" || getServiceForModel(model) === provider.service;
    return sum + (includeProvider ? provider.available() : 0);
  }, 0);
}
|
||||
|
@ -109,33 +114,6 @@ export class KeyPool {
|
|||
provider.recheck();
|
||||
}
|
||||
|
||||
/**
 * Maps a model name to the service that hosts it.
 *
 * Delegates to the `getServiceForModel` helper imported from "../models" so
 * the model-name rules live in one place instead of being duplicated here.
 *
 * @param model - The model name to look up.
 * @returns The service responsible for the model.
 * @throws Error if the helper does not recognise the model.
 */
private getServiceForModel(model: Model): LLMService {
  // Unqualified call: resolves to the imported module-level function, not
  // this method.
  return getServiceForModel(model);
}
|
||||
|
||||
/**
 * Finds the key provider registered for a service.
 *
 * @param service - The service whose provider is wanted.
 * @returns The provider whose `service` equals the argument.
 * @throws Error if no registered provider handles the service.
 */
private getKeyProvider(service: LLMService): KeyProvider {
  const provider = this.keyProviders.find(
    (candidate) => candidate.service === service
  );
  // Fail loudly here rather than letting callers dereference `undefined`.
  if (!provider) {
    throw new Error(`No key provider registered for service '${service}'`);
  }
  return provider;
}
|
||||
|
|
|
@ -205,6 +205,33 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
|||
return (req.modelFamily = modelFamily);
|
||||
}
|
||||
|
||||
/**
 * Maps a model name to the service that hosts it.
 *
 * Rules are checked in order and the first match wins.
 *
 * @param model - The model name to look up.
 * @returns The service responsible for the model.
 * @throws Error if the model name matches none of the rules.
 */
export function getServiceForModel(model: string): LLMService {
  // https://platform.openai.com/docs/models/model-endpoint-compatibility
  const openAiPrefixes = ["gpt", "text-embedding-ada", "dall-e"];
  if (openAiPrefixes.some((prefix) => model.startsWith(prefix))) {
    return "openai";
  }

  // https://console.anthropic.com/docs/api/reference#parameters
  if (model.startsWith("claude-")) {
    return "anthropic";
  }

  // https://developers.generativeai.google.com/models/language
  if (model.includes("gemini")) {
    return "google-ai";
  }

  // https://docs.mistral.ai/platform/endpoints
  if (model.includes("mistral")) {
    return "mistral-ai";
  }

  // AWS offers models from a few providers
  // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
  if (model.startsWith("anthropic.claude")) {
    return "aws";
  }

  if (model.startsWith("azure")) {
    return "azure";
  }

  throw new Error(`Unknown service for model '${model}'`);
}
|
||||
|
||||
/**
 * Exhaustiveness guard: always throws, reporting the value it was given.
 *
 * @param x - A value whose static type is `never`.
 * @throws Error always, with the value interpolated into the message.
 */
function assertNever(x: never): never {
  const message = `Called assertNever with argument ${x}.`;
  throw new Error(message);
}
|
||||
|
|
Loading…
Reference in New Issue