From 012c917b6fb37d0351f88fda73d3a81417a18d6f Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:58:07 +0200 Subject: [PATCH] Wrapping completions and chat/completions endpoint (#2) * rebase and squash commits on latest main * cargo fmt * fix: 2038y problem --------- Co-authored-by: michaelfeil --- Dockerfile | 2 +- docs/openapi.json | 178 ++++++++++++ router/src/completion.rs | 610 +++++++++++++++++++++++++++++++++++++++ router/src/lib.rs | 25 +- router/src/server.rs | 298 ++++++++++++++++++- 5 files changed, 1107 insertions(+), 6 deletions(-) create mode 100644 router/src/completion.rs diff --git a/Dockerfile b/Dockerfile index 34109d0..587ab9b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 @@ RUN cargo build --release # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM debian:bullseye-slim as pytorch-install -ARG PYTORCH_VERSION=2.0.0 +ARG PYTORCH_VERSION=2.0.1 ARG PYTHON_VERSION=3.9 ARG CUDA_VERSION=11.8 ARG MAMBA_VERSION=23.1.0-1 diff --git a/docs/openapi.json b/docs/openapi.json index 9a67238..ade487e 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -102,6 +102,184 @@ } } }, + "/completions": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Completion request. Enable stream of token by setting `stream == true`", + "description": "Completion request. Enable stream of token by setting `stream == true`", + "operationId": "completions_generate", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CompatCompletionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CompletionsResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/CompletionsResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, + "/chat/completions": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens via Chat", + "description": "Generate tokens via Chat", + "operationId": "chatcompletions_generate", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CompatChatCompletionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionsResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": 
"#/components/schemas/ChatCompletionsStreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, "/generate": { "post": { "tags": [ diff --git a/router/src/completion.rs b/router/src/completion.rs new file mode 100644 index 0000000..dede1a3 --- /dev/null +++ b/router/src/completion.rs @@ -0,0 +1,610 @@ +/// Copyright 2023 Michael Feil, text-generation-inference contributors +/// +/// Licensed under the Apache License, Version 2.0 (the "License"); +/// you may not use this file except in compliance with the License. +/// You may obtain a copy of the License at +/// +/// http://www.apache.org/licenses/LICENSE-2.0 +/// +/// Unless required by applicable law or agreed to in writing, software +/// distributed under the License is distributed on an "AS IS" BASIS, +/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +/// See the License for the specific language governing permissions and +/// limitations under the License. 
+/// + +/// Converting generate to completions and chat/completions protocol +use crate::{ + default_max_new_tokens, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, + Info, OpenaiStreamType, StreamDetails, Token, +}; +use axum::extract::Extension; +use axum::response::sse::Event; +use axum::Json; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[derive(Clone, Debug, Deserialize, ToSchema)] +pub(crate) struct CompatCompletionRequest { + #[schema(example = "My name is Michael and I")] + pub prompt: String, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] + pub best_of: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + nullable = true, + default = "null", + example = 0.5 + )] + pub temperature: Option, + #[serde(default)] + #[schema( + exclusive_minimum = -2.0, + nullable = true, + default = "null", + example = 0.0 + )] + pub presence_penalty: Option, + // #[serde(default)] + // #[schema(exclusive_minimum = 0, nullable = true, default = 1, example = 1)] + // pub n: Option, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] + pub top_k: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + maximum = 1.0, + nullable = true, + default = "null", + example = 0.95 + )] + pub top_p: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + maximum = 1.0, + nullable = true, + default = "null", + example = 0.95 + )] + pub typical_p: Option, + #[serde(default)] + #[schema(default = "false", example = true)] + pub do_sample: bool, + #[serde(default = "default_max_new_tokens")] + #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] + pub max_tokens: u32, + #[serde(default)] + #[schema(nullable = true, default = "null", example = false)] + pub echo: Option, + #[serde(default)] + #[schema(inline, max_items = 4, example = json ! 
(["photographer"]))] + pub stop: Vec, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub truncate: Option, + #[serde(default)] + #[schema(default = "false", example = true)] + pub watermark: bool, + #[serde(default)] + #[schema(default = "false")] + pub decoder_input_details: bool, + #[serde(default)] + #[schema( + exclusive_minimum = 0, + nullable = true, + default = "null", + example = "null" + )] + pub seed: Option, + #[serde(default)] + #[schema(default = "false")] + pub stream: bool, +} + +impl From for GenerateRequest { + fn from(req: CompatCompletionRequest) -> Self { + let presence_penalty = req.presence_penalty; + let presence_penalty = match presence_penalty { + Some(presence_penalty) => Some((presence_penalty + 2.0) / 2.0), + None => None, + }; + Self { + inputs: req.prompt, + parameters: GenerateParameters { + best_of: req.best_of, + temperature: req.temperature, + repetition_penalty: presence_penalty, + top_k: req.top_k, + top_p: req.top_p, + typical_p: req.typical_p, + do_sample: req.do_sample, + max_new_tokens: req.max_tokens, + return_full_text: req.echo, + stop: req.stop, + truncate: req.truncate, + watermark: req.watermark, + details: true, + decoder_input_details: req.decoder_input_details, + seed: req.seed, + }, + } + } +} + +#[derive(Clone, Debug, ToSchema, Deserialize, Serialize)] +pub(crate) enum ChatRole { + #[serde(rename = "user")] + User, + #[serde(rename = "assistant")] + Assistant, + #[serde(rename = "system")] + System, +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub(crate) struct ChatFormatterPrePost { + pre: String, + post: String, +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub(crate) struct ChatFormatter { + user_template: ChatFormatterPrePost, + assistant_template: ChatFormatterPrePost, + system_template: ChatFormatterPrePost, +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub(crate) struct ChatMessage { + #[schema(example = "user")] + role: ChatRole, + #[schema(example = "What is the capital of Bavaria?")] + content: String, + // user: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub(crate) struct ChatDeltaStreamMessage { + #[schema(example = "user")] + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[schema(example = "What is the capital of Bavaria?")] + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + // user: Option, +} + +#[derive(Clone, Debug, Deserialize, ToSchema)] +pub(crate) struct CompatChatCompletionRequest { + pub messages: Vec, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] + pub best_of: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + nullable = true, + default = "null", + example = 0.5 + )] + pub temperature: Option, + #[serde(default)] + #[schema( + exclusive_minimum = -2.0, + nullable = true, + default = "null", + example = 0.0 + )] + pub presence_penalty: Option, + // #[serde(default)] + // #[schema(exclusive_minimum = 0, nullable = true, default = 1, example = 1)] + // pub n: Option, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] + pub top_k: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + maximum = 1.0, + nullable = true, + default = "null", + example = 0.95 + )] + pub top_p: Option, + #[serde(default)] + #[schema( + exclusive_minimum = 0.0, + maximum = 1.0, + nullable = true, + default = "null", + example = 0.95 + )] + pub 
typical_p: Option, + #[serde(default)] + #[schema(default = "false", example = true)] + pub do_sample: bool, + #[serde(default = "default_max_new_tokens")] + #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] + pub max_tokens: u32, + #[serde(default)] + #[schema(nullable = true, default = "null", example = false)] + pub echo: Option, + #[serde(default)] + #[schema(inline, max_items = 4, example = json ! (["photographer"]))] + pub stop: Vec, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub truncate: Option, + #[serde(default)] + #[schema(default = "false", example = true)] + pub watermark: bool, + #[serde(default)] + #[schema(default = "false")] + pub decoder_input_details: bool, + #[serde(default)] + #[schema( + exclusive_minimum = 0, + nullable = true, + default = "null", + example = "null" + )] + pub seed: Option, + #[serde(default)] + #[schema(default = "false")] + pub stream: bool, + // #[serde(default)] + // #[schema(nullable = true, default = "null", example = "null")] + // pub user: Option, +} + +pub(crate) fn chat_to_generate_request( + req: CompatChatCompletionRequest, + formatter: ChatFormatter, +) -> GenerateRequest { + let mut prompt = String::from(""); + for m in req.messages { + // let role = m.role + let template = match m.role { + ChatRole::Assistant => &formatter.assistant_template, + ChatRole::System => &formatter.system_template, + ChatRole::User => &formatter.user_template, + }; + prompt.push_str(&template.pre); + prompt.push_str(&m.content); + prompt.push_str(&template.post); + } + let presence_penalty = req.presence_penalty; + let presence_penalty = match presence_penalty { + Some(presence_penalty) => Some((presence_penalty + 2.0) / 2.0), + None => None, + }; + + GenerateRequest { + inputs: prompt, + parameters: GenerateParameters { + best_of: req.best_of, + temperature: req.temperature, + repetition_penalty: presence_penalty, + top_k: req.top_k, + top_p: req.top_p, + typical_p: req.typical_p, + do_sample: req.do_sample, + max_new_tokens: req.max_tokens, + return_full_text: req.echo, + stop: req.stop, + truncate: req.truncate, + watermark: req.watermark, + details: true, + decoder_input_details: req.decoder_input_details, + seed: req.seed, + }, + } +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Usage { + #[schema(example = 1)] + pub total_tokens: u32, + #[schema(example = 1)] + pub completion_tokens: u32, + #[schema(example = 1)] + pub prompt_tokens: u32, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct CompletionChoices { + #[schema(example = "test")] + pub text: String, + #[schema(example = "length")] + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + // pub generated_tokens: u32, + // logprobs is not implemented, send None + pub logprobs: Option>, + #[schema(example = 0)] + pub index: u32, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct CompletionsResponse { + #[schema(example = "cmpl-abcdefgehij1234")] + pub id: String, + #[schema(example = "text_completion")] + pub object: String, + #[schema(example = 1589478379)] + pub created: u64, + #[schema(example = "tgi")] + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatCompletionChoices { + #[schema(example = "test")] + pub message: ChatMessage, + #[schema(example = "length")] + pub finish_reason: Option, + // pub generated_tokens: u32, + #[schema(example = 0)] + pub index: u32, 
+} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatCompletionDeltaStreamChoices { + #[schema(example = "test")] + pub delta: ChatDeltaStreamMessage, + #[schema(example = "length")] + pub finish_reason: Option, + // pub generated_tokens: u32, + #[schema(example = 0)] + pub index: u32, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatCompletionsResponse { + #[schema(example = "chatcmpl-abcdefgehij1234")] + pub id: String, + #[schema(example = "chat.completion")] + pub object: String, + #[schema(example = 1589478380)] + pub created: u64, + #[schema(example = "tgi")] + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatCompletionsStreamResponse { + #[schema(example = "chatcmpl-abcdefgehij1234")] + pub id: String, + #[schema(example = "chat.completion.chunk")] + pub object: String, + #[schema(example = 1589478380)] + pub created: u64, + #[schema(example = "tgi")] + pub model: String, + pub choices: Vec, +} + +pub(crate) fn get_chatformatter() -> ChatFormatter { + // TODO: improve reading this, e.g. at startup once from a chat_config.json + let chat_user_pre: String = match std::env::var_os("TGICHAT_USER_PRE") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + let chat_user_post: String = match std::env::var_os("TGICHAT_USER_POST") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + let chat_ass_pre: String = match std::env::var_os("TGICHAT_ASS_PRE") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + let chat_ass_post: String = match std::env::var_os("TGICHAT_ASS_POST") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + let chat_sys_pre: String = match std::env::var_os("TGICHAT_SYS_PRE") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + let chat_sys_post: String = match std::env::var_os("TGICHAT_SYS_POST") { + Some(v) => v.into_string().unwrap(), + None => String::from(""), + }; + + ChatFormatter { + user_template: ChatFormatterPrePost { + pre: chat_user_pre, + post: chat_user_post, + }, + assistant_template: ChatFormatterPrePost { + pre: chat_ass_pre, + post: chat_ass_post, + }, + system_template: ChatFormatterPrePost { + pre: chat_sys_pre, + post: chat_sys_post, + }, + } +} + +pub(crate) async fn generate_to_completions( + resp: Json, + info: Extension, +) -> Json { + // let details = resp.details.as_ref().ok_or("details missing"); //; + let details = resp.details.as_ref(); + + let gen_tokens = match details { + Some(details) => details.generated_tokens, + None => 0, + }; + let finish_reason = match details { + Some(details) => Some(details.finish_reason.clone()), + None => None, + }; + let prefill_len = match details { + Some(details) => details.prefill.len() as u32, + None => 0, + }; + + let choices = CompletionChoices { + text: resp.generated_text.clone(), + finish_reason: finish_reason, + logprobs: None, + index: 0, + }; + let usage = Some(Usage { + completion_tokens: gen_tokens, + total_tokens: gen_tokens + prefill_len, + prompt_tokens: prefill_len, + }); + let created_time = create_timestamp(); + let model = info.0.model_id; + let resp: CompletionsResponse = CompletionsResponse { + choices: vec![choices], + created: created_time, + id: String::from(format!("cmpl-{}", created_time)), + object: String::from("text_completion"), + model, + usage, + }; + Json(resp.into()) +} + +pub(crate) async fn generate_to_chatcompletions( + resp: Json, + info: Extension, +) -> Json { + // let 
details = resp.details.as_ref().ok_or("details missing"); //; + let details = resp.details.as_ref(); + + let gen_tokens = match details { + Some(details) => details.generated_tokens, + None => 0, + }; + let finish_reason = match details { + Some(details) => Some(details.finish_reason.clone()), + None => None, + }; + let prefill_len = match details { + Some(details) => details.prefill.len() as u32, + None => 0, + }; + + let choices = ChatCompletionChoices { + message: ChatMessage { + role: ChatRole::Assistant, + content: resp.generated_text.clone(), + }, + finish_reason: finish_reason, + index: 0, + }; + let usage = Usage { + completion_tokens: gen_tokens, + total_tokens: gen_tokens + prefill_len, + prompt_tokens: prefill_len, + }; + let created_time = create_timestamp(); + let model = info.0.model_id; + let resp = ChatCompletionsResponse { + choices: vec![choices], + created: created_time, + id: String::from(format!("chatcmpl-{}", created_time)), + object: String::from("chat.completion"), + model, + usage, + }; + Json(resp.into()) +} + +pub (crate) fn create_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time went backwards") + .as_secs() as u64 +} + +pub(crate) fn chat_start_message( + created_time: u64, + model_name: &String, +) -> ChatCompletionsStreamResponse { + let choices: ChatCompletionDeltaStreamChoices = ChatCompletionDeltaStreamChoices { + delta: ChatDeltaStreamMessage { + content: None, + role: Some(ChatRole::Assistant), + }, + finish_reason: None, + index: 0, + }; + ChatCompletionsStreamResponse { + choices: vec![choices], + created: created_time, + id: String::from(format!("chatcmpl-{}", created_time)), + object: String::from("chat.completion.chunk"), + model: model_name.to_owned(), + } +} + +pub(crate) fn create_streaming_event( + // st: StreamResponse, + stream_type: &OpenaiStreamType, + created_time: u64, + details: Option, + token: Token, + model_name: &String, +) -> Event { + match stream_type { + &OpenaiStreamType::ChatCompletionsStreamResponse => { + let choices: ChatCompletionDeltaStreamChoices = ChatCompletionDeltaStreamChoices { + delta: ChatDeltaStreamMessage { + content: Some(token.text), + role: None, + }, + finish_reason: match details { + Some(i) => Some(i.finish_reason), + None => None, + }, + index: 0, + }; + let response = ChatCompletionsStreamResponse { + choices: vec![choices], + created: created_time, + id: String::from(format!("chatcmpl-{}", created_time)), + object: String::from("chat.completion.chunk"), + model: model_name.to_owned(), + }; + Event::default().json_data(response).expect("cannot parse ChatCompletionsStreamResponse") + } + &OpenaiStreamType::CompletionsResponse => { + let choices = CompletionChoices { + text: token.text, + finish_reason: match details { + Some(i) => Some(i.finish_reason), + None => None, + }, + logprobs: None, + index: 0, + }; + + let response = CompletionsResponse { + choices: vec![choices], + created: created_time, + id: String::from(format!("cmpl-{}", created_time)), + object: String::from("text_completion"), + model: model_name.to_owned(), + usage: None, + }; + Event::default().json_data(response).expect("cannot parse streamed CompletionsResponse") + } + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 7dff7a1..2bdf3d1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,5 +1,21 @@ -mod health; +/// Copyright 2023 text-generation-inference contributors +/// +/// Licensed under the Apache License, Version 2.0 (the "License"); +/// you may not use this file except 
in compliance with the License. +/// You may obtain a copy of the License at +/// +/// http://www.apache.org/licenses/LICENSE-2.0 +/// +/// Unless required by applicable law or agreed to in writing, software +/// distributed under the License is distributed on an "AS IS" BASIS, +/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +/// See the License for the specific language governing permissions and +/// limitations under the License. +/// /// Text Generation Inference Webserver +mod health; + +pub mod completion; mod infer; mod queue; pub mod server; @@ -211,7 +227,7 @@ pub struct Token { special: bool, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Clone)] #[serde(rename_all(serialize = "snake_case"))] pub(crate) enum FinishReason { #[schema(rename = "length")] @@ -278,6 +294,11 @@ pub(crate) struct StreamResponse { pub details: Option, } +pub enum OpenaiStreamType { + ChatCompletionsStreamResponse, + CompletionsResponse, +} + #[derive(Serialize, ToSchema)] pub(crate) struct ErrorResponse { pub error: String, diff --git a/router/src/server.rs b/router/src/server.rs index 9af9495..4b5bf20 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,11 +1,33 @@ +/// Copyright 2023 text-generation-inference contributors +/// +/// Licensed under the Apache License, Version 2.0 (the "License"); +/// you may not use this file except in compliance with the License. +/// You may obtain a copy of the License at +/// +/// http://www.apache.org/licenses/LICENSE-2.0 +/// +/// Unless required by applicable law or agreed to in writing, software +/// distributed under the License is distributed on an "AS IS" BASIS, +/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +/// See the License for the specific language governing permissions and +/// limitations under the License. +/// + /// HTTP Server logic +use crate::completion::{ + chat_start_message, chat_to_generate_request, create_streaming_event, + generate_to_chatcompletions, generate_to_completions, get_chatformatter, create_timestamp, ChatCompletionChoices, + ChatCompletionDeltaStreamChoices, ChatCompletionsResponse, ChatCompletionsStreamResponse, + ChatDeltaStreamMessage, ChatMessage, ChatRole, CompatChatCompletionRequest, + CompatCompletionRequest, CompletionChoices, CompletionsResponse, Usage, +}; use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, - GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, - StreamDetails, StreamResponse, Token, Validation, + GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, + OpenaiStreamType, PrefillToken, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -58,7 +80,7 @@ async fn compat_generate( infer: Extension, req: Json, ) -> Result)> { - let mut req = req.0; + let mut req: CompatGenerateRequest = req.0; // default return_full_text given the pipeline_tag if req.parameters.return_full_text.is_none() { @@ -77,6 +99,107 @@ async fn compat_generate( } } +/// Plain Completion request. 
Enable stream of token by setting `stream == true` +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/completions", + request_body = CompatCompletionRequest, + responses( + (status = 200, description = "Generated Text", + content( + ("application/json" = CompletionsResponse), + ("text/event-stream" = CompletionsResponse), + )), + (status = 424, description = "Generation Error", body = ErrorResponse, + example = json ! ({"error": "Request failed during generation"})), + (status = 429, description = "Model is overloaded", body = ErrorResponse, + example = json ! ({"error": "Model is overloaded"})), + (status = 422, description = "Input validation error", body = ErrorResponse, + example = json ! ({"error": "Input validation error"})), + (status = 500, description = "Incomplete generation", body = ErrorResponse, + example = json ! ({"error": "Incomplete generation"})), + ) + )] +#[instrument(skip(infer, req))] +async fn completions_generate( + info: Extension, + infer: Extension, + req: Json, +) -> Result)> { + let req = req.0; + + if req.stream { + Ok(generate_stream_openai( + infer, + Json(req.into()), + OpenaiStreamType::CompletionsResponse, + info.model_id.clone(), + ) + .await + .into_response()) + } else { + let (headers, generation) = generate(infer, Json(req.into())).await?; + + let generation = generate_to_completions(generation, info).await; + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(generation.0)).into_response()) + } +} + +/// Chat Completion request. Enable stream of token by setting `stream == true` +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/chat/completions", + request_body = CompatChatCompletionRequest, + responses( + (status = 200, description = "Generated Text", + content( + ("application/json" = ChatCompletionsResponse), + ("text/event-stream" = ChatCompletionsStreamResponse), + )), + (status = 424, description = "Generation Error", body = ErrorResponse, + example = json ! ({"error": "Request failed during generation"})), + (status = 429, description = "Model is overloaded", body = ErrorResponse, + example = json ! ({"error": "Model is overloaded"})), + (status = 422, description = "Input validation error", body = ErrorResponse, + example = json ! ({"error": "Input validation error"})), + (status = 500, description = "Incomplete generation", body = ErrorResponse, + example = json ! 
({"error": "Incomplete generation"})), + ) + )] +#[instrument(skip(infer, req))] +async fn chatcompletions_generate( + info: Extension, + infer: Extension, + req: Json, +) -> Result)> { + let stream = req.stream; + let req: CompatChatCompletionRequest = req.0; + // TODO: move this somewhere else + + let chat_formatter = get_chatformatter(); + let req: GenerateRequest = chat_to_generate_request(req, chat_formatter); + + if stream { + Ok(generate_stream_openai( + infer, + Json(req.into()), + OpenaiStreamType::ChatCompletionsStreamResponse, + info.model_id.clone(), + ) + .await + .into_response()) + } else { + let (headers, generation) = generate(infer, Json(req.into())).await?; + + let generation = generate_to_chatcompletions(generation, info).await; + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(generation.0)).into_response()) + } +} + /// Text Generation Inference endpoint info #[utoipa::path( get, @@ -330,6 +453,7 @@ time_per_token, seed, ) )] + async fn generate_stream( infer: Extension, req: Json, @@ -491,6 +615,158 @@ async fn generate_stream( (headers, Sse::new(stream).keep_alive(KeepAlive::default())) } +async fn generate_stream_openai( + infer: Extension, + req: Json, + stream_type: OpenaiStreamType, + model_name: String, +) -> ( + HeaderMap, + Sse>>, +) { + let span = tracing::Span::current(); + let start_time = Instant::now(); + let created_time = create_timestamp(); + metrics::increment_counter!("tgi_request_count"); + + tracing::debug!("Input: {}", req.0.inputs); + + let compute_characters = req.0.inputs.chars().count(); + + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-characters", + compute_characters.to_string().parse().unwrap(), + ); + headers.insert("X-Accel-Buffering", "no".parse().unwrap()); + + let stream = async_stream::stream! 
{ + // Inference + let mut end_reached = false; + let mut error = false; + + let details = req.0.parameters.details; + + let best_of = req.0.parameters.best_of.unwrap_or(1); + if best_of != 1 { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else if req.0.parameters.decoder_input_details { + let err = InferError::from(ValidationError::PrefillDetailsStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } else { + match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { + // Keep permit as long as generate_stream lives + Ok((_permit, mut response_stream)) => { + // Server-Sent Event stream + match stream_type { + OpenaiStreamType::ChatCompletionsStreamResponse => { + let start_msg = chat_start_message(created_time, &model_name); + yield Ok(Event::from(Event::default().json_data(start_msg).unwrap())) + }, + _ => () + }; + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + match response { + // Prefill is ignored + InferStreamResponse::Prefill(_) => {} + // Yield event for every new token + InferStreamResponse::Token(token) => { + tracing::debug!(parent: &span, "Token: {:?}", token); + let stream_event = create_streaming_event(&stream_type, created_time, None, token, &model_name); + + yield Ok(stream_event); + } + // Yield event for last token and compute timings + InferStreamResponse::End { + token, + generated_text, + start, + queued, + } => { + // Token details + let details = match details { + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), + generated_tokens: generated_text.generated_tokens, + seed: generated_text.seed, + }), + false => None, + }; + + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", format!("{time_per_token:?}")); + span.record("seed", format!("{:?}", generated_text.seed)); + + // Metrics + metrics::increment_counter!("tgi_request_success"); + metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + + // create Openai StreamResponse + end_reached = true; + + tracing::debug!(parent: &span, "Output: {}", generated_text.text); + tracing::info!(parent: &span, "Success"); + + let stream_event = create_streaming_event(&stream_type, created_time, details, token, &model_name); + yield Ok(stream_event); + yield Ok(Event::default().data("[DONE]")); + break; + } + 
} + } + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + break; + } + } + } + }, + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + } + } + // Check if generation reached the end + // Skip if we already sent an error + if !end_reached && !error { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + yield Ok(Event::from(err)); + } + } + }; + + (headers, Sse::new(stream).keep_alive(KeepAlive::default())) +} + /// Prometheus metrics scrape endpoint #[utoipa::path( get, @@ -535,6 +811,8 @@ pub async fn run( compat_generate, generate, generate_stream, + completions_generate, + chatcompletions_generate, metrics, ), components( @@ -552,6 +830,18 @@ pub async fn run( StreamResponse, StreamDetails, ErrorResponse, + // completions messages + CompatCompletionRequest, + CompatChatCompletionRequest, + ChatMessage, + ChatRole, + CompletionsResponse, + Usage, + CompletionChoices, + ChatCompletionsResponse, + ChatCompletionChoices, + ChatCompletionsStreamResponse, + ChatDeltaStreamMessage, ChatCompletionDeltaStreamChoices, ) ), tags( @@ -672,6 +962,8 @@ pub async fn run( .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) + .route("/completions", post(completions_generate)) + .route("/chat/completions", post(chatcompletions_generate)) // AWS Sagemaker route .route("/invocations", post(compat_generate)) // Base Health route