feat(router): add legacy route for api-inference support (#88)
This commit is contained in:
parent
65e2f1624e
commit
21340f24ba
|
@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters {
|
||||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
|
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "true")]
|
#[schema(default = "true")]
|
||||||
|
@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest {
|
||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
|
pub(crate) struct CompatGenerateRequest {
|
||||||
|
#[schema(example = "My name is Olivier and I")]
|
||||||
|
pub inputs: String,
|
||||||
|
#[serde(default = "default_parameters")]
|
||||||
|
pub parameters: GenerateParameters,
|
||||||
|
#[serde(default)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CompatGenerateRequest> for GenerateRequest {
|
||||||
|
fn from(req: CompatGenerateRequest) -> Self {
|
||||||
|
Self {
|
||||||
|
inputs: req.inputs,
|
||||||
|
parameters: req.parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
pub struct PrefillToken {
|
pub struct PrefillToken {
|
||||||
#[schema(example = 0)]
|
#[schema(example = 0)]
|
||||||
id: u32,
|
id: u32,
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
text: String,
|
text: String,
|
||||||
#[schema(nullable = true, example = -0.34)]
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
logprob: f32,
|
logprob: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +122,7 @@ pub struct Token {
|
||||||
id: u32,
|
id: u32,
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
text: String,
|
text: String,
|
||||||
#[schema(nullable = true, example = -0.34)]
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
logprob: f32,
|
logprob: f32,
|
||||||
#[schema(example = "false")]
|
#[schema(example = "false")]
|
||||||
special: bool,
|
special: bool,
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
|
||||||
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
|
||||||
|
Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
/// Compatibility route with api-inference and AzureML
|
||||||
|
#[instrument(skip(infer))]
|
||||||
|
async fn compat_generate(
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
req: Json<CompatGenerateRequest>,
|
||||||
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
// switch on stream
|
||||||
|
let req = req.0;
|
||||||
|
if req.stream {
|
||||||
|
Ok(generate_stream(infer, Json(req.into()))
|
||||||
|
.await
|
||||||
|
.into_response())
|
||||||
|
} else {
|
||||||
|
let (headers, generation) = generate(infer, Json(req.into())).await?;
|
||||||
|
// wrap generation inside a Vec to match api-inference
|
||||||
|
Ok((headers, Json(vec![generation.0])).into_response())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||||
async fn generate(
|
async fn generate(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
@ -404,7 +424,7 @@ pub async fn run(
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
.route("/", post(generate))
|
.route("/", post(compat_generate))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
|
|
Loading…
Reference in New Issue