feat(router): add legacy route for api-inference support (#88)

This commit is contained in:
OlivierDehaene 2023-02-27 14:56:58 +01:00 committed by GitHub
parent 65e2f1624e
commit 21340f24ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 7 deletions

View File

@ -86,6 +86,26 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
#[serde(default)]
#[allow(dead_code)]
pub stream: bool,
}
impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
parameters: req.parameters,
}
}
}
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]

View File

@ -1,8 +1,9 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
// switch on stream
let req = req.0;
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response())
}
}
/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
async fn generate(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let start_time = Instant::now();
@ -404,7 +424,7 @@ pub async fn run(
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate))
.route("/", post(compat_generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))