feat(router): add legacy route for api-inference support (#88)
This commit is contained in:
parent
65e2f1624e
commit
21340f24ba
|
@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters {
|
|||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||
pub max_new_tokens: u32,
|
||||
#[serde(default)]
|
||||
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
|
||||
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
||||
pub stop: Vec<String>,
|
||||
#[serde(default)]
|
||||
#[schema(default = "true")]
|
||||
|
@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest {
|
|||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
// Request payload accepted by the legacy compatibility route (`compat_generate`).
// Carries the same `inputs` / `parameters` pair as `GenerateRequest`, plus a
// `stream` flag. Plain `//` comments are used here on purpose: the struct
// derives `ToSchema`, and doc comments would presumably end up in the generated
// schema descriptions — confirm before converting them to `///`.
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct CompatGenerateRequest {
|
||||
// Input text for the request; the attribute below only supplies the schema example.
#[schema(example = "My name is Olivier and I")]
|
||||
pub inputs: String,
|
||||
// When `parameters` is absent from the payload, serde calls `default_parameters`
// (defined outside this view) to fill it in.
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
// `stream` falls back to `bool::default()` (false) when absent from the payload.
#[serde(default)]
|
||||
// NOTE(review): `compat_generate` reads `req.stream`, so this allow looks
// unnecessary — confirm and remove, or document why it is needed.
#[allow(dead_code)]
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
/// Converts a compatibility request into a regular [`GenerateRequest`].
///
/// `inputs` and `parameters` are moved over unchanged; the `stream` flag has no
/// counterpart in the target built here and is dropped.
impl From<CompatGenerateRequest> for GenerateRequest {
|
||||
fn from(req: CompatGenerateRequest) -> Self {
|
||||
Self {
|
||||
inputs: req.inputs,
|
||||
parameters: req.parameters,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serializable token record with an id, its text and a log-probability.
// Presumably describes one token of the prompt (prefill) — confirm against the
// code that builds it. Plain `//` comments are used because the struct derives
// `ToSchema`, and doc comments would presumably leak into the schema descriptions.
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct PrefillToken {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
// NOTE(review): the two `#[schema(nullable = true, ...)]` lines below differ only
// in whitespace (`-0.34` vs `- 0.34`); they look like the removed and added sides
// of the same diff line, so only one of them belongs in the real file.
#[schema(nullable = true, example = -0.34)]
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
// Documented as nullable in the schema even though the Rust type is a plain `f32`.
logprob: f32,
|
||||
}
|
||||
|
||||
|
@ -102,7 +122,7 @@ pub struct Token {
|
|||
id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
#[schema(nullable = true, example = -0.34)]
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
#[schema(example = "false")]
|
||||
special: bool,
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
/// HTTP Server logic
|
||||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::{
|
||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
||||
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||
CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
|
||||
GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
|
||||
Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
|
@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
|
|||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
/// Compatibility route with api-inference and AzureML
///
/// Accepts a [`CompatGenerateRequest`] and dispatches on its `stream` flag:
/// when set, the request is forwarded to `generate_stream`; otherwise it is
/// forwarded to `generate` and the single result is returned as a one-element
/// JSON array.
///
/// # Errors
///
/// Errors returned by `generate` are propagated unchanged. The streaming branch
/// always returns `Ok` from this function.
|
||||
#[instrument(skip(infer))]
|
||||
async fn compat_generate(
|
||||
infer: Extension<Infer>,
|
||||
req: Json<CompatGenerateRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||
// switch on stream
|
||||
// Unwrap the `Json` extractor to get at the payload itself.
let req = req.0;
|
||||
if req.stream {
|
||||
// `req.into()` converts to `GenerateRequest` (the `stream` flag is dropped).
Ok(generate_stream(infer, Json(req.into()))
|
||||
.await
|
||||
// Both branches are erased to a concrete response via `into_response()`.
.into_response())
|
||||
} else {
|
||||
// `?` forwards the `(StatusCode, Json<ErrorResponse>)` error from `generate`.
let (headers, generation) = generate(infer, Json(req.into())).await?;
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(vec![generation.0])).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
/// Health check method
|
||||
#[instrument(skip(infer))]
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
|
@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||
async fn generate(
|
||||
infer: Extension<Infer>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
|
||||
|
@ -404,7 +424,7 @@ pub async fn run(
|
|||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
.route("/", post(generate))
|
||||
.route("/", post(compat_generate))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/", get(health))
|
||||
|
|
Loading…
Reference in New Issue