diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index e1de253b..465bd4fc 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -231,6 +231,9 @@ class Client:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response
@@ -322,6 +325,9 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -592,6 +598,9 @@ class AsyncClient:
                 Return the decoder input token logprobs and ids
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             Response: generated response
@@ -682,6 +691,9 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             top_n_tokens (`int`):
                 Return the `n` most likely tokens at each step
+            grammar (`Grammar`):
+                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation
+                of the text to match a regular expression or JSON schema.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
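As context for the grammar docstrings added to the client above, the snippet below is a minimal sketch of how the parameter is meant to be used from the Python client. It assumes the `Grammar` and `GrammarType` helpers exported by `text_generation.types` and a TGI server running locally; names and URL are assumptions, not part of this diff.

# Sketch: constrain generation with a JSON-schema grammar (Grammar/GrammarType names assumed from text_generation.types).
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://127.0.0.1:8080")
response = client.generate(
    "Extract the name and age: David is 25 years old.",
    max_new_tokens=32,
    grammar=Grammar(
        type=GrammarType.Json,
        value={
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
    ),
)
print(response.generated_text)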
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d89bacb5..a97b9b50 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -51,6 +51,7 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    pub completion_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
@@ -265,6 +266,76 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub struct CompletionRequest {
+    /// UNUSED
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String,
+
+    /// The prompt to generate completions for.
+    #[schema(example = "What is Deep Learning?")]
+    pub prompt: String,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(default = "32")]
+    pub max_tokens: Option<u32>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+
+    /// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
+    /// please see the completion_template field in the model's tokenizer_config.json file for completion template.
+    #[serde(default)]
+    pub suffix: Option<String>,
+
+    #[serde(default)]
+    pub repetition_penalty: Option<f32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
+    pub frequency_penalty: Option<f32>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct Completion {
+    pub id: String,
+    pub object: String,
+    #[schema(example = "1706270835")]
+    pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<CompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionComplete {
+    pub index: u32,
+    pub text: String,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
@@ -347,7 +418,7 @@ pub(crate) struct ChatCompletionTopLogprob {
     logprob: f32,
 }
 
-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
@@ -390,7 +461,15 @@ impl ChatCompletion {
         }
     }
 }
-
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionCompleteChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub choices: Vec<CompletionComplete>,
+    pub model: String,
+    pub system_fingerprint: String,
+}
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
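To make the request and response types added above concrete, here is a rough sketch of a non-streaming call to the endpoint they model (the route itself is wired up in router/src/server.rs below). The payload keys mirror the CompletionRequest fields; the host, port, and model string are placeholders, not part of this diff.

# Sketch: non-streaming POST to the new /v1/completions route (URL and model id are placeholders).
import requests

payload = {
    "model": "tgi",                      # UNUSED by the router, but part of the OpenAI schema
    "prompt": "What is Deep Learning?",  # maps to CompletionRequest.prompt
    "max_tokens": 32,                    # maps to CompletionRequest.max_tokens
    "temperature": 0.7,
    "top_p": 0.95,
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8080/v1/completions", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()  # shaped like Completion: id, object, created, model, choices, usage
print(body["choices"][0]["text"])
print(body["usage"])  # prompt_tokens / completion_tokens / total_tokens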
diff --git a/router/src/server.rs b/router/src/server.rs
index 9c7046d9..9c956a73 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,12 +3,16 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk,
-    ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs,
-    ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
-    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
+    BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
+    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
+    Validation,
+};
+use crate::{
+    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
+    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
+    CompletionRequest, VertexRequest, VertexResponse,
 };
 use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
@@ -536,6 +540,183 @@ async fn generate_stream_internal(
     (headers, stream)
 }
 
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn completions(
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let stream = req.stream;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let seed = req.seed;
+
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: req.prompt.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: req.temperature,
+            repetition_penalty: req.repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
+            top_k: None,
+            top_p: req.top_p,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: None,
+        },
+    };
+
+    if stream {
+        let on_message_callback = move |stream_token: StreamResponse| {
+            let event = Event::default();
+
+            let current_time = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                .as_secs();
+
+            event
+                .json_data(CompletionCompleteChunk {
+                    id: "".to_string(),
+                    object: "text_completion".to_string(),
+                    created: current_time,
+
+                    choices: vec![CompletionComplete {
+                        finish_reason: "".to_string(),
+                        index: 0,
+                        logprobs: None,
+                        text: stream_token.token.text,
+                    }],
+
+                    model: info.model_id.clone(),
+                    system_fingerprint: format!(
+                        "{}-{}",
+                        info.version,
+                        info.docker_label.unwrap_or("native")
+                    ),
+                })
+                .map_or_else(
+                    |e| {
+                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
+
+        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        Ok((headers, sse).into_response())
+    } else {
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let details = generation.details.ok_or((
+            // this should never happen but handle if details are missing unexpectedly
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(ErrorResponse {
+                error: "No details in generation".to_string(),
+                error_type: "no details".to_string(),
+            }),
+        ))?;
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices: vec![CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: 0,
+                logprobs: None,
+                text: generation.generated_text,
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        };
+
+        Ok((headers, Json(response)).into_response())
+    }
+}
+
 /// Generate tokens
 #[utoipa::path(
     post,
@@ -993,6 +1174,7 @@ pub async fn run(
         generate,
         generate_stream,
         chat_completions,
+        completions,
         tokenize,
         metrics,
     ),
@@ -1012,6 +1194,9 @@ pub async fn run(
         ChatCompletionLogprobs,
         ChatCompletionTopLogprob,
         ChatCompletion,
+        CompletionRequest,
+        CompletionComplete,
+        CompletionCompleteChunk,
         GenerateParameters,
         PrefillToken,
         Token,
@@ -1184,6 +1369,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
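The streaming branch of the completions handler above emits server-sent events whose payloads are CompletionCompleteChunk objects. The sketch below shows one way to consume that stream from Python; it assumes the usual "data: {json}" SSE framing used by the other TGI streaming routes, and the URL and payload values are placeholders.

# Sketch: consume the SSE stream produced when "stream": true (assumes "data: {json}" framing per event).
import json
import requests

payload = {"model": "tgi", "prompt": "Deep learning is", "max_tokens": 32, "stream": True}
with requests.post("http://127.0.0.1:8080/v1/completions", json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip keep-alives and comments
        data = line[len(b"data:"):].strip()
        if data == b"[DONE]":  # some OpenAI-compatible servers send a terminal sentinel
            break
        chunk = json.loads(data)  # a CompletionCompleteChunk
        print(chunk["choices"][0]["text"], end="", flush=True)
print()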
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e7b0b9e2..684f4524 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -36,7 +36,6 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
-    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
@@ -48,6 +47,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
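Because the new route mirrors OpenAI's legacy completions API, it can also be exercised through the openai Python SDK by pointing base_url at the TGI server. This is only a sketch under that assumption: the model string is ignored by the router (the field is marked UNUSED above), and the URL and dummy API key are placeholders.

# Sketch: call the TGI /v1/completions route through the OpenAI Python SDK (>= 1.0).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8080/v1", api_key="-")  # api_key required by the SDK, unused by TGI
completion = client.completions.create(
    model="tgi",  # UNUSED by the router, any string works
    prompt="What is Deep Learning?",
    max_tokens=32,
    temperature=0.7,
)
print(completion.choices[0].text)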