Wrapping completions and chat/completions endpoints (#2)

* rebase and squash commits on latest main

* cargo fmt

* fix: 2038y problem

---------

Co-authored-by: michaelfeil <me@michaelfeil.eu>
Michael Feil 2023-09-27 17:58:07 +02:00 committed by GitHub
parent f93012d59c
commit 012c917b6f
5 changed files with 1107 additions and 6 deletions
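
For orientation, a minimal sketch (not part of this commit) of how a client might call the new `/completions` route added here; it assumes a running TGI router on `localhost:3000` and the `reqwest`, `tokio`, and `serde_json` crates.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Non-streaming request against the OpenAI-style route added in this change.
    let body = json!({
        "prompt": "My name is Michael and I",
        "max_tokens": 20,
        "stream": false
    });
    let resp: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/completions") // illustrative host/port
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // The body follows the CompletionsResponse shape defined further down.
    println!("{}", resp["choices"][0]["text"]);
    Ok(())
}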

Dockerfile

@ -39,7 +39,7 @@ RUN cargo build --release
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM debian:bullseye-slim as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTORCH_VERSION=2.0.1
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1

docs/openapi.json

@ -102,6 +102,184 @@
}
}
},
"/completions": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Completion request. Enable stream of token by setting `stream == true`",
"description": "Completion request. Enable stream of token by setting `stream == true`",
"operationId": "completions_generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CompatCompletionRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CompletionsResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/CompletionsResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
},
"/chat/completions": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens via Chat",
"description": "Generate tokens via Chat",
"operationId": "chatcompletions_generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CompatChatCompletionRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionsResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionsStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
},
"/generate": {
"post": {
"tags": [

router/src/completion.rs (new file, 610 lines)

@ -0,0 +1,610 @@
/// Copyright 2023 Michael Feil, text-generation-inference contributors
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Converting generate to completions and chat/completions protocol
use crate::{
default_max_new_tokens, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Info, OpenaiStreamType, StreamDetails, Token,
};
use axum::extract::Extension;
use axum::response::sse::Event;
use axum::Json;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatCompletionRequest {
#[schema(example = "My name is Michael and I")]
pub prompt: String,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 0.5
)]
pub temperature: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = -2.0,
nullable = true,
default = "null",
example = 0.0
)]
pub presence_penalty: Option<f32>,
// #[serde(default)]
// #[schema(exclusive_minimum = 0, nullable = true, default = 1, example = 1)]
// pub n: Option<i32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub top_p: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub typical_p: Option<f32>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_tokens: u32,
#[serde(default)]
#[schema(nullable = true, default = "null", example = false)]
pub echo: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "false")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>,
#[serde(default)]
#[schema(default = "false")]
pub stream: bool,
}
impl From<CompatCompletionRequest> for GenerateRequest {
fn from(req: CompatCompletionRequest) -> Self {
// Map the OpenAI-style presence_penalty range [-2.0, 2.0] onto TGI's
// repetition_penalty range [0.0, 2.0]; 0.0 maps to the neutral value 1.0.
let presence_penalty = req.presence_penalty.map(|p| (p + 2.0) / 2.0);
Self {
inputs: req.prompt,
parameters: GenerateParameters {
best_of: req.best_of,
temperature: req.temperature,
repetition_penalty: presence_penalty,
top_k: req.top_k,
top_p: req.top_p,
typical_p: req.typical_p,
do_sample: req.do_sample,
max_new_tokens: req.max_tokens,
return_full_text: req.echo,
stop: req.stop,
truncate: req.truncate,
watermark: req.watermark,
details: true,
decoder_input_details: req.decoder_input_details,
seed: req.seed,
},
}
}
}
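A quick worked check (illustrative, not from the commit) of the `(presence_penalty + 2.0) / 2.0` shift used above, which maps OpenAI's `[-2.0, 2.0]` presence_penalty range onto the `[0.0, 2.0]` range passed on as TGI's `repetition_penalty`; the helper name is hypothetical.

fn to_repetition_penalty(presence_penalty: f32) -> f32 {
    (presence_penalty + 2.0) / 2.0
}

fn main() {
    assert_eq!(to_repetition_penalty(-2.0), 0.0); // lower bound
    assert_eq!(to_repetition_penalty(0.0), 1.0); // neutral: no repetition penalty
    assert_eq!(to_repetition_penalty(2.0), 2.0); // upper bound
}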
#[derive(Clone, Debug, ToSchema, Deserialize, Serialize)]
pub(crate) enum ChatRole {
#[serde(rename = "user")]
User,
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "system")]
System,
}
#[derive(Clone, Debug, Serialize, ToSchema)]
pub(crate) struct ChatFormatterPrePost {
pre: String,
post: String,
}
#[derive(Clone, Debug, Serialize, ToSchema)]
pub(crate) struct ChatFormatter {
user_template: ChatFormatterPrePost,
assistant_template: ChatFormatterPrePost,
system_template: ChatFormatterPrePost,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatMessage {
#[schema(example = "user")]
role: ChatRole,
#[schema(example = "What is the capital of Bavaria?")]
content: String,
// user: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatDeltaStreamMessage {
#[schema(example = "user")]
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<ChatRole>,
#[schema(example = "What is the capital of Bavaria?")]
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
// user: Option<String>,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatChatCompletionRequest {
pub messages: Vec<ChatMessage>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 0.5
)]
pub temperature: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = -2.0,
nullable = true,
default = "null",
example = 0.0
)]
pub presence_penalty: Option<f32>,
// #[serde(default)]
// #[schema(exclusive_minimum = 0, nullable = true, default = 1, example = 1)]
// pub n: Option<u32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub top_p: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub typical_p: Option<f32>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_tokens: u32,
#[serde(default)]
#[schema(nullable = true, default = "null", example = false)]
pub echo: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "false")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>,
#[serde(default)]
#[schema(default = "false")]
pub stream: bool,
// #[serde(default)]
// #[schema(nullable = true, default = "null", example = "null")]
// pub user: Option<String>,
}
pub(crate) fn chat_to_generate_request(
req: CompatChatCompletionRequest,
formatter: ChatFormatter,
) -> GenerateRequest {
let mut prompt = String::from("");
for m in req.messages {
// let role = m.role
let template = match m.role {
ChatRole::Assistant => &formatter.assistant_template,
ChatRole::System => &formatter.system_template,
ChatRole::User => &formatter.user_template,
};
prompt.push_str(&template.pre);
prompt.push_str(&m.content);
prompt.push_str(&template.post);
}
// Same presence_penalty -> repetition_penalty shift as in the completions request.
let presence_penalty = req.presence_penalty.map(|p| (p + 2.0) / 2.0);
GenerateRequest {
inputs: prompt,
parameters: GenerateParameters {
best_of: req.best_of,
temperature: req.temperature,
repetition_penalty: presence_penalty,
top_k: req.top_k,
top_p: req.top_p,
typical_p: req.typical_p,
do_sample: req.do_sample,
max_new_tokens: req.max_tokens,
return_full_text: req.echo,
stop: req.stop,
truncate: req.truncate,
watermark: req.watermark,
details: true,
decoder_input_details: req.decoder_input_details,
seed: req.seed,
},
}
}
#[derive(Serialize, ToSchema)]
pub(crate) struct Usage {
#[schema(example = 1)]
pub total_tokens: u32,
#[schema(example = 1)]
pub completion_tokens: u32,
#[schema(example = 1)]
pub prompt_tokens: u32,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct CompletionChoices {
#[schema(example = "test")]
pub text: String,
#[schema(example = "length")]
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>,
// pub generated_tokens: u32,
// logprobs is not implemented, send None
pub logprobs: Option<Vec<u32>>,
#[schema(example = 0)]
pub index: u32,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct CompletionsResponse {
#[schema(example = "cmpl-abcdefgehij1234")]
pub id: String,
#[schema(example = "text_completion")]
pub object: String,
#[schema(example = 1589478379)]
pub created: u64,
#[schema(example = "tgi")]
pub model: String,
pub choices: Vec<CompletionChoices>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatCompletionChoices {
#[schema(example = "test")]
pub message: ChatMessage,
#[schema(example = "length")]
pub finish_reason: Option<FinishReason>,
// pub generated_tokens: u32,
#[schema(example = 0)]
pub index: u32,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatCompletionDeltaStreamChoices {
#[schema(example = "test")]
pub delta: ChatDeltaStreamMessage,
#[schema(example = "length")]
pub finish_reason: Option<FinishReason>,
// pub generated_tokens: u32,
#[schema(example = 0)]
pub index: u32,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatCompletionsResponse {
#[schema(example = "chatcmpl-abcdefgehij1234")]
pub id: String,
#[schema(example = "chat.completion")]
pub object: String,
#[schema(example = 1589478380)]
pub created: u64,
#[schema(example = "tgi")]
pub model: String,
pub choices: Vec<ChatCompletionChoices>,
pub usage: Usage,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatCompletionsStreamResponse {
#[schema(example = "chatcmpl-abcdefgehij1234")]
pub id: String,
#[schema(example = "chat.completion.chunk")]
pub object: String,
#[schema(example = 1589478380)]
pub created: u64,
#[schema(example = "tgi")]
pub model: String,
pub choices: Vec<ChatCompletionDeltaStreamChoices>,
}
pub(crate) fn get_chatformatter() -> ChatFormatter {
// TODO: improve reading this, e.g. at startup once from a chat_config.json
// Read one template marker from the environment, defaulting to an empty string.
fn env_or_empty(key: &str) -> String {
match std::env::var_os(key) {
Some(v) => v.into_string().expect("TGICHAT_* values must be valid UTF-8"),
None => String::new(),
}
}
ChatFormatter {
user_template: ChatFormatterPrePost {
pre: env_or_empty("TGICHAT_USER_PRE"),
post: env_or_empty("TGICHAT_USER_POST"),
},
assistant_template: ChatFormatterPrePost {
pre: env_or_empty("TGICHAT_ASS_PRE"),
post: env_or_empty("TGICHAT_ASS_POST"),
},
system_template: ChatFormatterPrePost {
pre: env_or_empty("TGICHAT_SYS_PRE"),
post: env_or_empty("TGICHAT_SYS_POST"),
},
}
}
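For illustration, a minimal sketch of how the `TGICHAT_*` pre/post markers wrap a user message when `chat_to_generate_request` assembles the prompt; the Llama-2-style marker values are assumptions, not defaults shipped by this commit.

fn main() {
    // Illustrative marker values; any prompt format can be configured this way.
    std::env::set_var("TGICHAT_USER_PRE", "[INST] ");
    std::env::set_var("TGICHAT_USER_POST", " [/INST]");

    let pre = std::env::var("TGICHAT_USER_PRE").unwrap_or_default();
    let post = std::env::var("TGICHAT_USER_POST").unwrap_or_default();
    let content = "What is the capital of Bavaria?";

    // Mirrors the pre + content + post concatenation in chat_to_generate_request.
    let prompt = format!("{pre}{content}{post}");
    assert_eq!(prompt, "[INST] What is the capital of Bavaria? [/INST]");
}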
pub(crate) async fn generate_to_completions(
resp: Json<GenerateResponse>,
info: Extension<Info>,
) -> Json<CompletionsResponse> {
let details = resp.details.as_ref();
let gen_tokens = details.map(|d| d.generated_tokens).unwrap_or(0);
let finish_reason = details.map(|d| d.finish_reason.clone());
let prefill_len = details.map(|d| d.prefill.len() as u32).unwrap_or(0);
let choices = CompletionChoices {
text: resp.generated_text.clone(),
finish_reason,
logprobs: None,
index: 0,
};
let usage = Some(Usage {
completion_tokens: gen_tokens,
total_tokens: gen_tokens + prefill_len,
prompt_tokens: prefill_len,
});
let created_time = create_timestamp();
let model = info.0.model_id;
let resp = CompletionsResponse {
choices: vec![choices],
created: created_time,
id: format!("cmpl-{}", created_time),
object: String::from("text_completion"),
model,
usage,
};
Json(resp)
}
pub(crate) async fn generate_to_chatcompletions(
resp: Json<GenerateResponse>,
info: Extension<Info>,
) -> Json<ChatCompletionsResponse> {
let details = resp.details.as_ref();
let gen_tokens = details.map(|d| d.generated_tokens).unwrap_or(0);
let finish_reason = details.map(|d| d.finish_reason.clone());
let prefill_len = details.map(|d| d.prefill.len() as u32).unwrap_or(0);
let choices = ChatCompletionChoices {
message: ChatMessage {
role: ChatRole::Assistant,
content: resp.generated_text.clone(),
},
finish_reason,
index: 0,
};
let usage = Usage {
completion_tokens: gen_tokens,
total_tokens: gen_tokens + prefill_len,
prompt_tokens: prefill_len,
};
let created_time = create_timestamp();
let model = info.0.model_id;
let resp = ChatCompletionsResponse {
choices: vec![choices],
created: created_time,
id: String::from(format!("chatcmpl-{}", created_time)),
object: String::from("chat.completion"),
model,
usage,
};
Json(resp)
}
// Unix timestamp in whole seconds; returning u64 seconds keeps this safe past 2038.
pub(crate) fn create_timestamp() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_secs()
}
pub(crate) fn chat_start_message(
created_time: u64,
model_name: &String,
) -> ChatCompletionsStreamResponse {
let choices: ChatCompletionDeltaStreamChoices = ChatCompletionDeltaStreamChoices {
delta: ChatDeltaStreamMessage {
content: None,
role: Some(ChatRole::Assistant),
},
finish_reason: None,
index: 0,
};
ChatCompletionsStreamResponse {
choices: vec![choices],
created: created_time,
id: String::from(format!("chatcmpl-{}", created_time)),
object: String::from("chat.completion.chunk"),
model: model_name.to_owned(),
}
}
pub(crate) fn create_streaming_event(
// st: StreamResponse,
stream_type: &OpenaiStreamType,
created_time: u64,
details: Option<StreamDetails>,
token: Token,
model_name: &String,
) -> Event {
match stream_type {
&OpenaiStreamType::ChatCompletionsStreamResponse => {
let choices: ChatCompletionDeltaStreamChoices = ChatCompletionDeltaStreamChoices {
delta: ChatDeltaStreamMessage {
content: Some(token.text),
role: None,
},
finish_reason: details.map(|d| d.finish_reason),
index: 0,
};
let response = ChatCompletionsStreamResponse {
choices: vec![choices],
created: created_time,
id: String::from(format!("chatcmpl-{}", created_time)),
object: String::from("chat.completion.chunk"),
model: model_name.to_owned(),
};
Event::default().json_data(response).expect("cannot parse ChatCompletionsStreamResponse")
}
&OpenaiStreamType::CompletionsResponse => {
let choices = CompletionChoices {
text: token.text,
finish_reason: details.map(|d| d.finish_reason),
logprobs: None,
index: 0,
};
let response = CompletionsResponse {
choices: vec![choices],
created: created_time,
id: String::from(format!("cmpl-{}", created_time)),
object: String::from("text_completion"),
model: model_name.to_owned(),
usage: None,
};
Event::default().json_data(response).expect("cannot parse streamed CompletionsResponse")
}
}
}

router/src/lib.rs

@ -1,5 +1,21 @@
mod health;
/// Copyright 2023 text-generation-inference contributors
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Text Generation Inference Webserver
mod health;
pub mod completion;
mod infer;
mod queue;
pub mod server;
@ -211,7 +227,7 @@ pub struct Token {
special: bool,
}
#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
@ -278,6 +294,11 @@ pub(crate) struct StreamResponse {
pub details: Option<StreamDetails>,
}
pub enum OpenaiStreamType {
ChatCompletionsStreamResponse,
CompletionsResponse,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
pub error: String,

router/src/server.rs

@ -1,11 +1,33 @@
/// Copyright 2023 text-generation-inference contributors
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// HTTP Server logic
use crate::completion::{
chat_start_message, chat_to_generate_request, create_streaming_event,
generate_to_chatcompletions, generate_to_completions, get_chatformatter, create_timestamp, ChatCompletionChoices,
ChatCompletionDeltaStreamChoices, ChatCompletionsResponse, ChatCompletionsStreamResponse,
ChatDeltaStreamMessage, ChatMessage, ChatRole, CompatChatCompletionRequest,
CompatCompletionRequest, CompletionChoices, CompletionsResponse, Usage,
};
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info,
OpenaiStreamType, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -58,7 +80,7 @@ async fn compat_generate(
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
let mut req: CompatGenerateRequest = req.0;
// default return_full_text given the pipeline_tag
if req.parameters.return_full_text.is_none() {
@ -77,6 +99,107 @@ async fn compat_generate(
}
}
/// Plain Completion request. Enable streaming of tokens by setting `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/completions",
request_body = CompatCompletionRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = CompletionsResponse),
("text/event-stream" = CompletionsResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn completions_generate(
info: Extension<Info>,
infer: Extension<Infer>,
req: Json<CompatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let req = req.0;
if req.stream {
Ok(generate_stream_openai(
infer,
Json(req.into()),
OpenaiStreamType::CompletionsResponse,
info.model_id.clone(),
)
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
let generation = generate_to_completions(generation, info).await;
// return the OpenAI-style completion response body
Ok((headers, Json(generation.0)).into_response())
}
}
/// Chat Completion request. Enable streaming of tokens by setting `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/chat/completions",
request_body = CompatChatCompletionRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = ChatCompletionsResponse),
("text/event-stream" = ChatCompletionsStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn chatcompletions_generate(
info: Extension<Info>,
infer: Extension<Infer>,
req: Json<CompatChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let stream = req.stream;
let req: CompatChatCompletionRequest = req.0;
// TODO: move this somewhere else
let chat_formatter = get_chatformatter();
let req: GenerateRequest = chat_to_generate_request(req, chat_formatter);
if stream {
Ok(generate_stream_openai(
infer,
Json(req.into()),
OpenaiStreamType::ChatCompletionsStreamResponse,
info.model_id.clone(),
)
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
let generation = generate_to_chatcompletions(generation, info).await;
// return the OpenAI-style chat completion response body
Ok((headers, Json(generation.0)).into_response())
}
}
/// Text Generation Inference endpoint info
#[utoipa::path(
get,
@ -330,6 +453,7 @@ time_per_token,
seed,
)
)]
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
@ -491,6 +615,158 @@ async fn generate_stream(
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
async fn generate_stream_openai(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
stream_type: OpenaiStreamType,
model_name: String,
) -> (
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
let start_time = Instant::now();
let created_time = create_timestamp();
metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs);
let compute_characters = req.0.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
headers.insert("X-Accel-Buffering", "no".parse().unwrap());
let stream = async_stream::stream! {
// Inference
let mut end_reached = false;
let mut error = false;
let details = req.0.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.0.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream
match stream_type {
OpenaiStreamType::ChatCompletionsStreamResponse => {
let start_msg = chat_start_message(created_time, &model_name);
yield Ok(Event::default().json_data(start_msg).unwrap())
},
_ => ()
};
while let Some(response) = response_stream.next().await {
match response {
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
tracing::debug!(parent: &span, "Token: {:?}", token);
let stream_event = create_streaming_event(&stream_type, created_time, None, token, &model_name);
yield Ok(stream_event);
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed));
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// create Openai StreamResponse
end_reached = true;
tracing::debug!(parent: &span, "Output: {}", generated_text.text);
tracing::info!(parent: &span, "Success");
let stream_event = create_streaming_event(&stream_type, created_time, details, token, &model_name);
yield Ok(stream_event);
yield Ok(Event::default().data("[DONE]"));
break;
}
}
}
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
break;
}
}
}
},
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
}
}
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
yield Ok(Event::from(err));
}
}
};
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
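A hedged sketch (not part of this diff) of consuming the SSE stream this handler produces; it assumes `reqwest` with its `stream` feature, plus `futures-util`, `tokio`, and `serde_json`, and an illustrative local endpoint. Each `data:` line carries a `ChatCompletionsStreamResponse` chunk, and the stream ends with `data: [DONE]`.

use futures_util::StreamExt;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/chat/completions") // illustrative host/port
        .json(&json!({
            "messages": [{"role": "user", "content": "What is the capital of Bavaria?"}],
            "stream": true
        }))
        .send()
        .await?;

    // Print raw SSE chunks as they arrive; a real client would parse each
    // `data: {...}` line into a ChatCompletionsStreamResponse and stop at [DONE].
    let mut body = resp.bytes_stream();
    while let Some(chunk) = body.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}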
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
@ -535,6 +811,8 @@ pub async fn run(
compat_generate,
generate,
generate_stream,
completions_generate,
chatcompletions_generate,
metrics,
),
components(
@ -552,6 +830,18 @@ pub async fn run(
StreamResponse,
StreamDetails,
ErrorResponse,
// completions messages
CompatCompletionRequest,
CompatChatCompletionRequest,
ChatMessage,
ChatRole,
CompletionsResponse,
Usage,
CompletionChoices,
ChatCompletionsResponse,
ChatCompletionChoices,
ChatCompletionsStreamResponse,
ChatDeltaStreamMessage, ChatCompletionDeltaStreamChoices,
)
),
tags(
@ -672,6 +962,8 @@ pub async fn run(
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/completions", post(completions_generate))
.route("/chat/completions", post(chatcompletions_generate))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route