diff --git a/router/src/lib.rs b/router/src/lib.rs index 1f23bfd3..d7cfa4c7 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] - #[schema(inline, max_items = 4, example = json!(["photographer"]))] + #[schema(inline, max_items = 4, example = json ! (["photographer"]))] pub stop: Vec, #[serde(default)] #[schema(default = "true")] @@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +pub(crate) struct CompatGenerateRequest { + #[schema(example = "My name is Olivier and I")] + pub inputs: String, + #[serde(default = "default_parameters")] + pub parameters: GenerateParameters, + #[serde(default)] + #[allow(dead_code)] + pub stream: bool, +} + +impl From for GenerateRequest { + fn from(req: CompatGenerateRequest) -> Self { + Self { + inputs: req.inputs, + parameters: req.parameters, + } + } +} + #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] id: u32, #[schema(example = "test")] text: String, - #[schema(nullable = true, example = -0.34)] + #[schema(nullable = true, example = - 0.34)] logprob: f32, } @@ -102,7 +122,7 @@ pub struct Token { id: u32, #[schema(example = "test")] text: String, - #[schema(nullable = true, example = -0.34)] + #[schema(nullable = true, example = - 0.34)] logprob: f32, #[schema(example = "false")] special: bool, diff --git a/router/src/server.rs b/router/src/server.rs index de96e397..83b0297e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,8 +1,9 @@ /// HTTP Server logic use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, - Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation, + CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, + GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token, + Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +/// Compatibility route with api-inference and AzureML +#[instrument(skip(infer))] +async fn compat_generate( + infer: Extension, + req: Json, +) -> Result)> { + // switch on stream + let req = req.0; + if req.stream { + Ok(generate_stream(infer, Json(req.into())) + .await + .into_response()) + } else { + let (headers, generation) = generate(infer, Json(req.into())).await?; + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(vec![generation.0])).into_response()) + } +} + /// Health check method #[instrument(skip(infer))] async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { @@ -84,7 +104,7 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json, req: Json, -) -> Result)> { +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); let start_time = Instant::now(); @@ -404,7 +424,7 @@ pub async fn run( // Create router let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - .route("/", post(generate)) + .route("/", post(compat_generate)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/", get(health))