feat: accept legacy request format and response (#1527)

This PR adds support for the legacy OpenAI `v1/completions` API.

This should allow TGI to be a drop-in replacement for OpenAI when using
tools that rely on the completions API.

Should fix:
https://github.com/huggingface/text-generation-inference/issues/1468
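
For context, a minimal sketch of how the new endpoint could be exercised once a TGI server is running. The base URL `http://localhost:8080` and the use of `requests` are assumptions for illustration; the `model` field is accepted for compatibility but unused by TGI:

```python
import requests

# Sketch only: assumes a TGI instance serving on localhost:8080
resp = requests.post(
    "http://localhost:8080/v1/completions",
    json={
        "model": "tgi",  # accepted for OpenAI compatibility, unused by TGI
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "temperature": 0.7,
    },
)
resp.raise_for_status()
completion = resp.json()
# The response mirrors OpenAI's legacy completion object:
# choices[0].text holds the generated text
print(completion["choices"][0]["text"])
```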
drbh committed on 2024-02-29 10:44:20 -05:00 · commit 3dd7da2198 · parent 9ed4d2c780
4 changed files with 286 additions and 9 deletions

clients/python/text_generation/client.py

@@ -231,6 +231,9 @@ class Client:
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
of the text to match a regular expression or JSON schema.
Returns:
Response: generated response
@@ -322,6 +325,9 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
of the text to match a regular expression or JSON schema.
Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -592,6 +598,9 @@ class AsyncClient:
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
of the text to match a regular expression or JSON schema.
Returns:
Response: generated response
@@ -682,6 +691,9 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
grammar (`Grammar`):
Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
of the text to match a regular expression or JSON schema.
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
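
The `grammar` docstrings added above describe a constrained-decoding parameter on `generate`/`generate_stream`. A rough sketch of how it might be used from the Python client follows; the `Grammar` import location and its constructor fields are assumptions based on these docstrings rather than the client source:

```python
from text_generation import Client
from text_generation.types import Grammar  # assumed location of the Grammar type

client = Client("http://localhost:8080")  # assumed local TGI endpoint

# Constrain the output to a simple regex, as described by the docstring above
response = client.generate(
    "What is the capital of France? Answer with a single word:",
    max_new_tokens=10,
    grammar=Grammar(type="regex", value="[A-Z][a-z]+"),  # assumed constructor shape
)
print(response.generated_text)
```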

router/src/lib.rs

@@ -51,6 +51,7 @@ pub struct HubModelInfo {
#[derive(Clone, Deserialize, Default)]
pub struct HubTokenizerConfig {
pub chat_template: Option<String>,
pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub bos_token: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
@@ -265,6 +266,76 @@ fn default_parameters() -> GenerateParameters {
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
pub prompt: String,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(default = "32")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
/// Please see the completion_template field in the model's tokenizer_config.json file for the completion template.
#[serde(default)]
pub suffix: Option<String>,
#[serde(default)]
pub repetition_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion {
pub id: String,
pub object: String,
#[schema(example = "1706270835")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<CompletionComplete>,
pub usage: Usage,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionComplete {
pub index: u32,
pub text: String,
pub logprobs: Option<Vec<f32>>,
pub finish_reason: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,
@@ -347,7 +418,7 @@ pub(crate) struct ChatCompletionTopLogprob {
logprob: f32,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
@@ -390,7 +461,15 @@ impl ChatCompletion {
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
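
The request/response types above map directly onto the wire format. As a quick illustration of the streaming path (`stream: true`, which yields `CompletionCompleteChunk` objects over SSE), here is a hedged consumer sketch; the base URL and the presence of a final `[DONE]` sentinel are assumptions:

```python
import json
import requests

# Sketch only: stream tokens from an assumed local TGI instance
with requests.post(
    "http://localhost:8080/v1/completions",
    json={"model": "tgi", "prompt": "Deep Learning is", "max_tokens": 32, "stream": True},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip blank lines and comments between SSE events
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":  # defensive: some OpenAI-style streams end with a sentinel
            break
        chunk = json.loads(payload)
        # Each chunk mirrors CompletionCompleteChunk: choices[0].text is the newly generated text
        print(chunk["choices"][0]["text"], end="", flush=True)
```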

router/src/server.rs

@@ -3,12 +3,16 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension;
@@ -536,6 +540,183 @@ async fn generate_stream_internal(
(headers, stream)
}
/// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v1/completions",
request_body = CompletionRequest,
responses(
(status = 200, description = "Generated Text", body = ChatCompletionChunk),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
// parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let seed = req.seed;
// if suffix is present throw an error
if req.suffix.is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Suffix is not supported and can be achieved by preprocessing the prompt."
.to_string(),
error_type: "suffix not supported".to_string(),
}),
));
}
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: req.prompt.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
max_new_tokens,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: None,
},
};
if stream {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
event
.json_data(CompletionCompleteChunk {
id: "".to_string(),
object: "text_completion".to_string(),
created: current_time,
choices: vec![CompletionComplete {
finish_reason: "".to_string(),
index: 0,
logprobs: None,
text: stream_token.token.text,
}],
model: info.model_id.clone(),
system_fingerprint: format!(
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
),
})
.map_or_else(
|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
},
|data| data,
)
};
let (headers, response_stream) = generate_stream_internal(
infer,
compute_type,
Json(generate_request),
on_message_callback,
)
.await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, Json(generation)) = generate(
Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let details = generation.details.ok_or((
// this should never happen but handle if details are missing unexpectedly
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "No details in generation".to_string(),
error_type: "no details".to_string(),
}),
))?;
let response = Completion {
id: "".to_string(),
object: "text_completion".to_string(),
created: current_time,
model: info.model_id.clone(),
system_fingerprint: format!(
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: 0,
logprobs: None,
text: generation.generated_text,
}],
usage: Usage {
prompt_tokens: details.prefill.len() as u32,
completion_tokens: details.generated_tokens,
total_tokens: details.prefill.len() as u32 + details.generated_tokens,
},
};
Ok((headers, Json(response)).into_response())
}
}
/// Generate tokens
#[utoipa::path(
post,
@@ -993,6 +1174,7 @@ pub async fn run(
generate,
generate_stream,
chat_completions,
completions,
tokenize,
metrics,
),
@@ -1012,6 +1194,9 @@ pub async fn run(
ChatCompletionLogprobs,
ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
CompletionCompleteChunk,
GenerateParameters,
PrefillToken,
Token,
@@ -1184,6 +1369,7 @@ pub async fn run(
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility)) .route("/vertex", post(vertex_compatibility))
.route("/tokenize", post(tokenize)) .route("/tokenize", post(tokenize))
.route("/health", get(health)) .route("/health", get(health))

server/text_generation_server/models/__init__.py

@@ -36,7 +36,6 @@ __all__ = [
"Model",
"BLOOMSharded",
"CausalLM",
"FlashCausalLM",
"GalacticaSharded", "GalacticaSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder", "SantaCoder",
@ -48,6 +47,7 @@ __all__ = [
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded from text_generation_server.models.flash_neox import FlashNeoXSharded