feat: accept legacy request format and response (#1527)
This PR adds support for the legacy OpenAI `v1/completions` API. This should allow TGI to be a drop-in replacement for OpenAI when using tools that rely on the completions API.

Should fix: https://github.com/huggingface/text-generation-inference/issues/1468
This commit is contained in: parent 9ed4d2c780 · commit 3dd7da2198
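For orientation, a minimal sketch of how a client might call the new endpoint once this lands. The base URL `http://localhost:8080` and the model id are assumptions; the route, request fields, and response keys come from the `CompletionRequest`/`Completion` types and handler added in this diff.

```python
# Minimal sketch of a non-streaming call to the new legacy endpoint.
# Assumes a TGI server is listening on http://localhost:8080 (hypothetical host/port).
import requests

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",  # accepted but currently unused by the router
    "prompt": "What is Deep Learning?",
    "max_tokens": 32,
    "temperature": 0.7,
    "seed": 42,
    "stream": False,
}

resp = requests.post("http://localhost:8080/v1/completions", json=payload, timeout=60)
resp.raise_for_status()
completion = resp.json()

# The response mirrors the OpenAI completions shape: choices, usage, etc.
print(completion["choices"][0]["text"])
print(completion["usage"])  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}
```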
```diff
@@ -231,6 +231,9 @@ class Client:
             Return the decoder input token logprobs and ids
         top_n_tokens (`int`):
             Return the `n` most likely tokens at each step
+        grammar (`Grammar`):
+            Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+            of the text to match a regular expression or JSON schema.

     Returns:
         Response: generated response
@@ -322,6 +325,9 @@ class Client:
             Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
         top_n_tokens (`int`):
             Return the `n` most likely tokens at each step
+        grammar (`Grammar`):
+            Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+            of the text to match a regular expression or JSON schema.

     Returns:
         Iterator[StreamResponse]: stream of generated tokens
@@ -592,6 +598,9 @@ class AsyncClient:
             Return the decoder input token logprobs and ids
         top_n_tokens (`int`):
             Return the `n` most likely tokens at each step
+        grammar (`Grammar`):
+            Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+            of the text to match a regular expression or JSON schema.

     Returns:
         Response: generated response
@@ -682,6 +691,9 @@ class AsyncClient:
             Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
         top_n_tokens (`int`):
             Return the `n` most likely tokens at each step
+        grammar (`Grammar`):
+            Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+            of the text to match a regular expression or JSON schema.

     Returns:
         AsyncIterator[StreamResponse]: stream of generated tokens
```
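Not part of the diff, but for context on the documented parameters above, a hedged sketch of calling the Python client with `top_n_tokens`. The server URL is an assumption, and the call simply mirrors the docstring entries shown in the hunks above.

```python
# Sketch of using the documented `top_n_tokens` parameter with the Python client.
# Assumes a TGI server at http://localhost:8080 (hypothetical) and that the
# `text_generation` package is installed.
from text_generation import Client

client = Client("http://localhost:8080")

response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=32,
    top_n_tokens=5,  # return the 5 most likely tokens at each step (see docstring above)
)

print(response.generated_text)
# The per-step most-likely tokens are reported alongside the generation details.
```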
```diff
@@ -51,6 +51,7 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    pub completion_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
@@ -265,6 +266,76 @@ fn default_parameters() -> GenerateParameters {
     }
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub struct CompletionRequest {
+    /// UNUSED
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String,
+
+    /// The prompt to generate completions for.
+    #[schema(example = "What is Deep Learning?")]
+    pub prompt: String,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(default = "32")]
+    pub max_tokens: Option<u32>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+
+    /// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
+    /// please see the completion_template field in the model's tokenizer_config.json file for completion template.
+    #[serde(default)]
+    pub suffix: Option<String>,
+
+    #[serde(default)]
+    pub repetition_penalty: Option<f32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
+    pub frequency_penalty: Option<f32>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct Completion {
+    pub id: String,
+    pub object: String,
+    #[schema(example = "1706270835")]
+    pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<CompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionComplete {
+    pub index: u32,
+    pub text: String,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
```
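For reference, a sketch of the JSON these response types serialize to. Only the key names and nesting are taken from the `Completion`, `CompletionComplete`, and `Usage` structs above; the concrete values (timestamp, fingerprint, finish reason, token counts) are illustrative assumptions, not output from a real run.

```python
# Illustrative shape of the non-streaming /v1/completions response, mirroring the
# Completion / CompletionComplete / Usage structs in the diff above. Values are made up.
example_completion = {
    "id": "",                              # currently left empty by the router
    "object": "text_completion",
    "created": 1706270835,                 # unix timestamp in seconds
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "system_fingerprint": "1.4.0-native",  # "{version}-{docker_label}"; exact value depends on the build
    "choices": [
        {
            "index": 0,
            "text": "Deep Learning is ...",
            "logprobs": None,              # not populated by this handler
            "finish_reason": "length",     # illustrative; the real value comes from the generation details
        }
    ],
    "usage": {
        "prompt_tokens": 5,
        "completion_tokens": 32,
        "total_tokens": 37,
    },
}

print(example_completion["choices"][0]["text"])
```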
```diff
@@ -347,7 +418,7 @@ pub(crate) struct ChatCompletionTopLogprob {
     logprob: f32,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
@@ -390,7 +461,15 @@ impl ChatCompletion {
         }
     }
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionCompleteChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub choices: Vec<CompletionComplete>,
+    pub model: String,
+    pub system_fingerprint: String,
+}
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
```
```diff
@@ -3,12 +3,16 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk,
-    ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs,
-    ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
+    BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
+    Validation,
 };
+use crate::{
+    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
+    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, VertexRequest, VertexResponse,
+};
 use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
```
```diff
@@ -536,6 +540,183 @@ async fn generate_stream_internal(
     (headers, stream)
 }

+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn completions(
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let stream = req.stream;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let seed = req.seed;
+
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: req.prompt.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: req.temperature,
+            repetition_penalty: req.repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
+            top_k: None,
+            top_p: req.top_p,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: None,
+        },
+    };
+
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
+}
+
 /// Generate tokens
 #[utoipa::path(
     post,
```
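To complement the handler above, a hedged sketch of consuming the SSE stream it produces when `stream` is true. The host/port are assumptions; the chunk keys follow the `CompletionCompleteChunk` struct, and each SSE `data:` line carries one serialized chunk.

```python
# Minimal sketch of reading the streaming variant of /v1/completions.
# Assumes a TGI server at http://localhost:8080 (hypothetical host/port).
import json
import requests

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",  # accepted but currently unused
    "prompt": "What is Deep Learning?",
    "max_tokens": 32,
    "stream": True,
    "seed": 42,
}

with requests.post(
    "http://localhost:8080/v1/completions", json=payload, stream=True, timeout=60
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip blank keep-alive/comment lines in the SSE stream
        data = line[len(b"data:"):].strip()
        if data == b"[DONE]":
            break  # defensive: some OpenAI-style servers send a terminator sentinel
        chunk = json.loads(data)  # shaped like CompletionCompleteChunk
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```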
```diff
@@ -993,6 +1174,7 @@ pub async fn run(
         generate,
         generate_stream,
         chat_completions,
+        completions,
         tokenize,
         metrics,
     ),
@@ -1012,6 +1194,9 @@ pub async fn run(
         ChatCompletionLogprobs,
         ChatCompletionTopLogprob,
         ChatCompletion,
+        CompletionRequest,
+        CompletionComplete,
+        CompletionCompleteChunk,
         GenerateParameters,
         PrefillToken,
         Token,
@@ -1184,6 +1369,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
```
```diff
@@ -36,7 +36,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
     "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -48,6 +47,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

 FLASH_ATTENTION = True

 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
```