feat: accept legacy request format and response (#1527)
This PR adds support for the legacy OpenAI `v1/completions` API, so that TGI can act as a drop-in replacement for OpenAI when used with tools that rely on the completions API.

Should fix: https://github.com/huggingface/text-generation-inference/issues/1468
parent 9ed4d2c780
commit 3dd7da2198
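For illustration, a minimal sketch of how the new route can be called once this lands. It assumes a TGI instance listening on `localhost:8080`; the request fields follow the `CompletionRequest` struct added below, and `model` is accepted for compatibility but unused by the router:

```python
# Illustrative only: legacy-style completion request against a local TGI server.
import requests

resp = requests.post(
    "http://localhost:8080/v1/completions",  # assumed local endpoint
    json={
        "model": "tgi",                  # accepted for OpenAI compatibility, UNUSED by the router
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "temperature": 0.7,
        "top_p": 0.95,
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0]["text"])
print(data["usage"])  # prompt_tokens / completion_tokens / total_tokens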

@@ -231,6 +231,9 @@ class Client:
        Return the decoder input token logprobs and ids
    top_n_tokens (`int`):
        Return the `n` most likely tokens at each step
+    grammar (`Grammar`):
+        Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+        of the text to match a regular expression or JSON schema.

Returns:
    Response: generated response

@@ -322,6 +325,9 @@ class Client:
        Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    top_n_tokens (`int`):
        Return the `n` most likely tokens at each step
+    grammar (`Grammar`):
+        Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+        of the text to match a regular expression or JSON schema.

Returns:
    Iterator[StreamResponse]: stream of generated tokens

@@ -592,6 +598,9 @@ class AsyncClient:
        Return the decoder input token logprobs and ids
    top_n_tokens (`int`):
        Return the `n` most likely tokens at each step
+    grammar (`Grammar`):
+        Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+        of the text to match a regular expression or JSON schema.

Returns:
    Response: generated response

@@ -682,6 +691,9 @@ class AsyncClient:
        Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    top_n_tokens (`int`):
        Return the `n` most likely tokens at each step
+    grammar (`Grammar`):
+        Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+        of the text to match a regular expression or JSON schema.

Returns:
    AsyncIterator[StreamResponse]: stream of generated tokens

@@ -51,6 +51,7 @@ pub struct HubModelInfo {
#[derive(Clone, Deserialize, Default)]
pub struct HubTokenizerConfig {
    pub chat_template: Option<String>,
+    pub completion_template: Option<String>,
    #[serde(deserialize_with = "token_serde::deserialize")]
    pub bos_token: Option<String>,
    #[serde(deserialize_with = "token_serde::deserialize")]

@@ -265,6 +266,76 @@ fn default_parameters() -> GenerateParameters {
    }
}

+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub struct CompletionRequest {
+    /// UNUSED
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String,
+
+    /// The prompt to generate completions for.
+    #[schema(example = "What is Deep Learning?")]
+    pub prompt: String,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(default = "32")]
+    pub max_tokens: Option<u32>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+
+    /// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
+    /// please see the completion_template field in the model's tokenizer_config.json file for completion template.
+    #[serde(default)]
+    pub suffix: Option<String>,
+
+    #[serde(default)]
+    pub repetition_penalty: Option<f32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
+    pub frequency_penalty: Option<f32>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct Completion {
+    pub id: String,
+    pub object: String,
+    #[schema(example = "1706270835")]
+    pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<CompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionComplete {
+    pub index: u32,
+    pub text: String,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
    pub id: String,

@@ -347,7 +418,7 @@ pub(crate) struct ChatCompletionTopLogprob {
    logprob: f32,
}

-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,

@@ -390,7 +461,15 @@ impl ChatCompletion {
        }
    }
}

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionCompleteChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub choices: Vec<CompletionComplete>,
+    pub model: String,
+    pub system_fingerprint: String,
+}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
    pub id: String,
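Purely as a sketch of what the new response types serialize to (all values below are invented; `system_fingerprint` is built as `"{version}-{docker_label}"` by the handler added further down):

```python
# Illustrative shapes only, mirroring the Completion, CompletionComplete,
# Usage and CompletionCompleteChunk structs above. Values are made up.
completion_response = {
    "id": "",
    "object": "text_completion",
    "created": 1706270835,
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "system_fingerprint": "1.4.0-native",  # assumed "{version}-{docker_label}" value
    "choices": [
        {"index": 0, "text": " Deep Learning is ...", "logprobs": None, "finish_reason": "length"}
    ],
    "usage": {"prompt_tokens": 6, "completion_tokens": 32, "total_tokens": 38},
}

stream_chunk = {  # one SSE event on the streaming path
    "id": "",
    "object": "text_completion",
    "created": 1706270835,
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "system_fingerprint": "1.4.0-native",
    "choices": [{"index": 0, "text": " Deep", "logprobs": None, "finish_reason": ""}],
}
```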

@@ -3,12 +3,16 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk,
-    ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs,
-    ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
+    BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
+    Validation,
+};
+use crate::{
+    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
+    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension;

@@ -536,6 +540,183 @@ async fn generate_stream_internal(
    (headers, stream)
}

+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn completions(
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let stream = req.stream;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let seed = req.seed;
+
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: req.prompt.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: req.temperature,
+            repetition_penalty: req.repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
+            top_k: None,
+            top_p: req.top_p,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: None,
+        },
+    };
+
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
+}
+
/// Generate tokens
#[utoipa::path(
    post,
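And a rough sketch of consuming the streaming branch of this handler (same assumed local server as before; the `data:` framing comes from the SSE response it builds):

```python
# Illustrative only: stream tokens from the new /v1/completions route.
import json
import requests

with requests.post(
    "http://localhost:8080/v1/completions",  # assumed local endpoint
    json={"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 32, "stream": True},
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":  # OpenAI-style terminator; may or may not be sent here
            break
        chunk = json.loads(payload)
        # Each chunk is a CompletionCompleteChunk with a single CompletionComplete choice.
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```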

@@ -993,6 +1174,7 @@ pub async fn run(
        generate,
        generate_stream,
        chat_completions,
+        completions,
        tokenize,
        metrics,
    ),

@@ -1012,6 +1194,9 @@ pub async fn run(
        ChatCompletionLogprobs,
        ChatCompletionTopLogprob,
        ChatCompletion,
+        CompletionRequest,
+        CompletionComplete,
+        CompletionCompleteChunk,
        GenerateParameters,
        PrefillToken,
        Token,

@@ -1184,6 +1369,7 @@ pub async fn run(
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
+        .route("/v1/completions", post(completions))
        .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))

@@ -36,7 +36,6 @@ __all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
-    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",

@@ -48,6 +47,7 @@ __all__ = [
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded