feat(router): add legacy route for api-inference support (#88)

This commit is contained in:
OlivierDehaene 2023-02-27 14:56:58 +01:00 committed by GitHub
parent 65e2f1624e
commit 21340f24ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 7 deletions

View File

@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32, pub max_new_tokens: u32,
#[serde(default)] #[serde(default)]
#[schema(inline, max_items = 4, example = json!(["photographer"]))] #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
#[serde(default)] #[serde(default)]
#[schema(default = "true")] #[schema(default = "true")]
@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
} }
// Request body accepted by the legacy compatibility route.
//
// It carries the same `inputs` and `parameters` fields as `GenerateRequest`,
// plus a `stream` flag. Plain `//` comments are used here on purpose: the
// `ToSchema` derive may surface `///` doc comments as schema descriptions,
// which would change the generated OpenAPI document.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
    // Input string, forwarded unchanged when converting to `GenerateRequest`.
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,
    // Falls back to `default_parameters()` when the field is absent from the payload.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    // Defaults to `false` when absent. The compatibility handler reads this
    // flag to pick the streaming or non-streaming path, so the field is live
    // and needs no `dead_code` suppression.
    #[serde(default)]
    pub stream: bool,
}
impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
parameters: req.parameters,
}
}
}
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken { pub struct PrefillToken {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, text: String,
#[schema(nullable = true, example = -0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, logprob: f32,
} }
@ -102,7 +122,7 @@ pub struct Token {
id: u32, id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, text: String,
#[schema(nullable = true, example = -0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, logprob: f32,
#[schema(example = "false")] #[schema(example = "false")]
special: bool, special: bool,

View File

@ -1,8 +1,9 @@
/// HTTP Server logic /// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse}; use crate::infer::{InferError, InferStreamResponse};
use crate::{ use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
// switch on stream
let req = req.0;
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response())
}
}
/// Health check method /// Health check method
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
@ -404,7 +424,7 @@ pub async fn run(
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate)) .route("/", post(compat_generate))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/", get(health)) .route("/", get(health))