diff --git a/backends/client/src/v2/pb/generate.v2.rs b/backends/client/src/v2/pb/generate.v2.rs new file mode 100644 index 00000000..3af5670b --- /dev/null +++ b/backends/client/src/v2/pb/generate.v2.rs @@ -0,0 +1,613 @@ +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthResponse {} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoResponse { + #[prost(bool, tag = "1")] + pub requires_padding: bool, + #[prost(string, tag = "2")] + pub dtype: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub device_type: ::prost::alloc::string::String, + #[prost(uint32, optional, tag = "4")] + pub window_size: ::core::option::Option, + #[prost(uint32, tag = "5")] + pub speculate: u32, +} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryResponse { + /// / Other shards urls + #[prost(string, repeated, tag = "1")] + pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheRequest { + /// / Optional batch id + #[prost(uint64, optional, tag = "1")] + pub id: ::core::option::Option, +} +/// / Empty response +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheResponse {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NextTokenChooserParameters { + /// / exponential scaling output probability distribution + #[prost(float, tag = "1")] + pub temperature: f32, + /// / restricting to the k highest probability elements + #[prost(uint32, tag = "2")] + pub top_k: u32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "3")] + pub top_p: f32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "4")] + pub typical_p: f32, + /// / apply sampling on the logits + #[prost(bool, tag = "5")] + pub do_sample: bool, + /// / random seed for sampling + #[prost(uint64, tag = "6")] + pub seed: u64, + /// / repetition penalty + #[prost(float, tag = "7")] + pub repetition_penalty: f32, + /// / frequency penalty + #[prost(float, tag = "9")] + pub frequency_penalty: f32, + /// / token watermarking using "A Watermark for Large Language Models" + #[prost(bool, tag = "8")] + pub watermark: bool, + /// / grammar (applied if not empty) + #[prost(string, tag = "10")] + pub grammar: ::prost::alloc::string::String, + /// / grammar type + #[prost(enumeration = "GrammarType", tag = "11")] + pub grammar_type: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StoppingCriteriaParameters { + /// / Maximum number of generated tokens + #[prost(uint32, tag = "1")] + pub max_new_tokens: u32, + /// / Optional stopping sequences + #[prost(string, repeated, tag = "2")] + pub 
stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / Ignore end of sequence token + /// / used for benchmarking + #[prost(bool, tag = "3")] + pub ignore_eos_token: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Request { + /// / Request ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / The generation context + #[prost(string, tag = "2")] + pub inputs: ::prost::alloc::string::String, + /// / Context truncation + #[prost(uint32, tag = "3")] + pub truncate: u32, + /// / Next Token Chooser Parameters + #[prost(message, optional, tag = "4")] + pub parameters: ::core::option::Option, + /// / Stopping Criteria Parameters + #[prost(message, optional, tag = "5")] + pub stopping_parameters: ::core::option::Option, + /// / Return prefill logprobs + #[prost(bool, tag = "6")] + pub prefill_logprobs: bool, + /// / Return most likely n tokens + #[prost(uint32, tag = "7")] + pub top_n_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Batch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests + #[prost(message, repeated, tag = "2")] + pub requests: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CachedBatch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests ids + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GeneratedText { + /// / Output + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + /// / Number of generated tokens + #[prost(uint32, tag = "2")] + pub generated_tokens: u32, + /// / Finish reason + #[prost(enumeration = "FinishReason", tag = "3")] + pub finish_reason: i32, + /// / Seed + #[prost(uint64, optional, tag = "4")] + pub seed: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Tokens { + /// / Token IDs + #[prost(uint32, repeated, tag = "1")] + pub ids: ::prost::alloc::vec::Vec, + /// / Logprobs + #[prost(float, repeated, tag = "2")] + pub logprobs: ::prost::alloc::vec::Vec, + /// / tokens + #[prost(string, repeated, tag = "3")] + pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / special + #[prost(bool, repeated, tag = "4")] + pub is_special: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Generation { + /// / Request ID + #[prost(uint64, tag = "1")] + pub request_id: u64, + /// / Prefill tokens (optional) + #[prost(message, optional, tag = "2")] + pub prefill_tokens: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub tokens: ::core::option::Option, + /// / Complete generated text + #[prost(message, optional, tag = "4")] + pub generated_text: ::core::option::Option, + /// / Top tokens + 
#[prost(message, repeated, tag = "5")] + pub top_tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchRequest { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub batch_id: u64, + /// / Requests to keep + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchResponse { + /// / Filtered Batch (cached) + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillRequest { + /// / Batch + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillResponse { + /// / Generation + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeRequest { + /// / Cached batches + #[prost(message, repeated, tag = "1")] + pub batches: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeResponse { + /// / Decodes + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupRequest { + /// / Batch to warmup on + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, + #[prost(uint32, tag = "2")] + pub max_input_length: u32, + #[prost(uint32, tag = "3")] + pub max_prefill_tokens: u32, + #[prost(uint32, tag = "4")] + pub max_total_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupResponse { + /// / Maximum number of tokens supported by the model + #[prost(uint32, optional, tag = "1")] + pub max_supported_total_tokens: ::core::option::Option, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum GrammarType { + None = 0, + Json = 1, + Regex = 2, +} +impl GrammarType { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + GrammarType::None => "GRAMMAR_TYPE_NONE", + GrammarType::Json => "GRAMMAR_TYPE_JSON", + GrammarType::Regex => "GRAMMAR_TYPE_REGEX", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GRAMMAR_TYPE_NONE" => Some(Self::None), + "GRAMMAR_TYPE_JSON" => Some(Self::Json), + "GRAMMAR_TYPE_REGEX" => Some(Self::Regex), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FinishReason { + Length = 0, + EosToken = 1, + StopSequence = 2, +} +impl FinishReason { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + FinishReason::Length => "FINISH_REASON_LENGTH", + FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN", + FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FINISH_REASON_LENGTH" => Some(Self::Length), + "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken), + "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence), + _ => None, + } + } +} +/// Generated client implementations. +pub mod text_generation_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::http::Uri; + use tonic::codegen::*; + #[derive(Debug, Clone)] + pub struct TextGenerationServiceClient { + inner: tonic::client::Grpc, + } + impl TextGenerationServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl TextGenerationServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> TextGenerationServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + >>::Error: + Into + Send + Sync, + { + TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. 
+ #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// / Model Info + pub async fn info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Info"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info")); + self.inner.unary(req, path, codec).await + } + /// / Service discovery + pub async fn service_discovery( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ServiceDiscovery", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "ServiceDiscovery", + )); + self.inner.unary(req, path, codec).await + } + /// / Empties batch cache + pub async fn clear_cache( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/ClearCache", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "ClearCache", + )); + self.inner.unary(req, path, codec).await + } + /// / Remove requests from a cached batch + pub async fn filter_batch( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v2.TextGenerationService/FilterBatch", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "FilterBatch", + )); + self.inner.unary(req, path, codec).await + } + /// / Warmup the model and compute max cache size + pub async fn warmup( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; 
+ let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Warmup"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "Warmup", + )); + self.inner.unary(req, path, codec).await + } + /// / Prefill batch and decode first token + pub async fn prefill( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Prefill"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "Prefill", + )); + self.inner.unary(req, path, codec).await + } + /// / Decode token for a list of prefilled batches + pub async fn decode( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Decode"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "Decode", + )); + self.inner.unary(req, path, codec).await + } + /// / Health check + pub async fn health( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Health"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v2.TextGenerationService", + "Health", + )); + self.inner.unary(req, path, codec).await + } + } +} diff --git a/backends/client/src/v2/pb/mod.rs b/backends/client/src/v2/pb/mod.rs new file mode 100644 index 00000000..095ead1f --- /dev/null +++ b/backends/client/src/v2/pb/mod.rs @@ -0,0 +1,6 @@ +// This file is @generated by prost-build. 
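The generated v2 client above exposes the whole service surface as plain async methods. A minimal usage sketch, assuming an illustrative TCP endpoint and import path (the module path mirrors the generated layout shown here; the real backend dials shards over unix sockets via `master_shard_uds_path`, and the client crate's actual re-exports may differ):

use text_generation_client::v2::pb::generate::v2::{
    text_generation_service_client::TextGenerationServiceClient, HealthRequest, InfoRequest,
};

// Illustrative only: connect to a shard and issue the two status RPCs.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `connect` accepts anything convertible into a tonic `Endpoint`.
    let mut client = TextGenerationServiceClient::connect("http://127.0.0.1:9000").await?;

    // Liveness probe: empty request, empty response.
    client.health(HealthRequest {}).await?;

    // Static shard information (dtype, device type, speculation depth, ...).
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!(
        "dtype={} device={} speculate={}",
        info.dtype, info.device_type, info.speculate
    );
    Ok(())
}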
+pub mod generate { + pub mod v2 { + include!("generate.v2.rs"); + } +} diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index d43f789e..968c1f45 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -107,20 +107,22 @@ impl Client { #[instrument(skip_all)] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_tokens: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let mut n_tokens = 0; let mut requests = Vec::new(); // Create requests while n_tokens < max_prefill_tokens { - let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut truncate = max_prefill_tokens - n_tokens; + if let Some(max_input_tokens) = max_input_tokens { + truncate = min(max_input_tokens, truncate); + } let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into()); if n_tokens == 0 { input_chunks.push( Chunk::Image(Image { @@ -136,7 +138,7 @@ impl Client { // been updated to support chunks. let mut inputs = String::new(); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + inputs.push_str(&"_test ".to_string().repeat(truncate as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. @@ -145,6 +147,12 @@ impl Client { )); } + let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens { + max_total_tokens - truncate + } else { + 1 + }; + requests.push(Request { id: 0, inputs, @@ -175,7 +183,7 @@ impl Client { grammar_type: GrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: max_total_tokens - truncate, + max_new_tokens, stop_sequences: vec![], ignore_eos_token: true, }), @@ -183,7 +191,7 @@ impl Client { top_n_tokens: 20, adapter_id: None, }); - n_tokens += max_input_length; + n_tokens += truncate; // Check max_batch_size if Some(requests.len()) == max_batch_size { @@ -195,19 +203,23 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: max_input_length, + max_tokens: max_input_tokens.unwrap_or(0), max_blocks: 0, }; let request = tonic::Request::new(WarmupRequest { batch: Some(batch), - max_input_length, + max_input_tokens, max_prefill_tokens, max_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); - Ok(response.max_supported_total_tokens) + Ok(( + response.max_supported_total_tokens, + response.max_input_tokens, + response.max_total_tokens, + )) } /// Generate one token for each request in the given batch diff --git a/backends/client/src/v3/pb/generate.v3.rs b/backends/client/src/v3/pb/generate.v3.rs new file mode 100644 index 00000000..95423e30 --- /dev/null +++ b/backends/client/src/v3/pb/generate.v3.rs @@ -0,0 +1,699 @@ +// This file is @generated by prost-build. 
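The `warmup` change above lets both `max_input_tokens` and `max_total_tokens` be `None`. A minimal standalone sketch of the per-request sizing it performs (an illustrative helper, not the actual code path in `client.rs`; it only restates the arithmetic of the hunk):

use std::cmp::min;

/// How each warmup request is sized once the limits became optional.
fn warmup_request_sizes(
    n_tokens: u32,
    max_prefill_tokens: u32,
    max_input_tokens: Option<u32>,
    max_total_tokens: Option<u32>,
) -> (u32, u32) {
    // Fill the remaining prefill budget, capped by the input limit when one is set.
    let mut truncate = max_prefill_tokens - n_tokens;
    if let Some(max_input_tokens) = max_input_tokens {
        truncate = min(max_input_tokens, truncate);
    }
    // With an explicit total limit, leave room for decoding; otherwise a single
    // generated token is enough for warmup.
    let max_new_tokens = match max_total_tokens {
        Some(max_total_tokens) => max_total_tokens - truncate,
        None => 1,
    };
    (truncate, max_new_tokens)
}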
+#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthResponse {} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InfoResponse { + #[prost(bool, tag = "1")] + pub requires_padding: bool, + #[prost(string, tag = "2")] + pub dtype: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub device_type: ::prost::alloc::string::String, + #[prost(uint32, optional, tag = "4")] + pub window_size: ::core::option::Option, + #[prost(uint32, tag = "5")] + pub speculate: u32, + #[prost(bool, tag = "6")] + pub support_chunking: bool, + #[prost(bool, tag = "7")] + pub use_prefix_caching: bool, + #[prost(string, tag = "8")] + pub attention_impl: ::prost::alloc::string::String, + #[prost(uint32, tag = "9")] + pub block_size: u32, +} +/// / Empty request +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServiceDiscoveryResponse { + /// / Other shards urls + #[prost(string, repeated, tag = "1")] + pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheRequest { + /// / Optional batch id + #[prost(uint64, optional, tag = "1")] + pub id: ::core::option::Option, +} +/// / Empty response +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearCacheResponse {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Image { + /// / Binary image data. + #[prost(bytes = "vec", tag = "1")] + pub data: ::prost::alloc::vec::Vec, + /// / Image MIME type. + #[prost(string, tag = "2")] + pub mimetype: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InputChunk { + #[prost(oneof = "input_chunk::Chunk", tags = "1, 2")] + pub chunk: ::core::option::Option, +} +/// Nested message and enum types in `InputChunk`. 
+pub mod input_chunk { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Chunk { + /// / Plain text data + #[prost(string, tag = "1")] + Text(::prost::alloc::string::String), + /// / Image data + #[prost(message, tag = "2")] + Image(super::Image), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Input { + #[prost(message, repeated, tag = "1")] + pub chunks: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NextTokenChooserParameters { + /// / exponential scaling output probability distribution + #[prost(float, tag = "1")] + pub temperature: f32, + /// / restricting to the k highest probability elements + #[prost(uint32, tag = "2")] + pub top_k: u32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "3")] + pub top_p: f32, + /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off + #[prost(float, tag = "4")] + pub typical_p: f32, + /// / apply sampling on the logits + #[prost(bool, tag = "5")] + pub do_sample: bool, + /// / random seed for sampling + #[prost(uint64, tag = "6")] + pub seed: u64, + /// / repetition penalty + #[prost(float, tag = "7")] + pub repetition_penalty: f32, + /// / frequency penalty + #[prost(float, tag = "9")] + pub frequency_penalty: f32, + /// / token watermarking using "A Watermark for Large Language Models" + #[prost(bool, tag = "8")] + pub watermark: bool, + /// / grammar (applied if not empty) + #[prost(string, tag = "10")] + pub grammar: ::prost::alloc::string::String, + /// / grammar type + #[prost(enumeration = "GrammarType", tag = "11")] + pub grammar_type: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StoppingCriteriaParameters { + /// / Maximum number of generated tokens + #[prost(uint32, tag = "1")] + pub max_new_tokens: u32, + /// / Optional stopping sequences + #[prost(string, repeated, tag = "2")] + pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / Ignore end of sequence token + /// / used for benchmarking + #[prost(bool, tag = "3")] + pub ignore_eos_token: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Request { + /// / Request ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / The generation context as chunks + #[prost(message, optional, tag = "8")] + pub input_chunks: ::core::option::Option, + /// / The generation context, stringified input_chunks + #[prost(string, tag = "2")] + pub inputs: ::prost::alloc::string::String, + /// / Context truncation + #[prost(uint32, tag = "3")] + pub truncate: u32, + /// / Next Token Chooser Parameters + #[prost(message, optional, tag = "4")] + pub parameters: ::core::option::Option, + /// / Stopping Criteria Parameters + #[prost(message, optional, tag = "5")] + pub stopping_parameters: ::core::option::Option, + /// / Return prefill logprobs + #[prost(bool, tag = "6")] + pub prefill_logprobs: bool, + /// / Return most likely n tokens + #[prost(uint32, tag = "7")] + pub top_n_tokens: u32, + /// / Paged attention blocks + #[prost(uint32, repeated, tag = "9")] + pub blocks: ::prost::alloc::vec::Vec, + /// / Paged attention slots + #[prost(uint32, repeated, tag = "10")] + pub slots: ::prost::alloc::vec::Vec, + /// / LORA adapter index + #[prost(string, optional, tag = "11")] + 
pub adapter_id: ::core::option::Option<::prost::alloc::string::String>, + /// / Tokens that can be retrieved from the KV cache. + /// / This value is set for the first prefill and never reset + #[prost(uint32, tag = "12")] + pub cache_len: u32, + /// / Context truncation + #[prost(bool, tag = "13")] + pub add_special_tokens: bool, + /// / Chunk of tokens that must be computed for the first prefill + /// / This value is set for the first prefill and never reset + #[prost(uint32, optional, tag = "14")] + pub chunk_len: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Batch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests + #[prost(message, repeated, tag = "2")] + pub requests: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, + /// / Maximum number of Paged Attention blocks + #[prost(uint32, tag = "5")] + pub max_blocks: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CachedBatch { + /// / Batch ID + #[prost(uint64, tag = "1")] + pub id: u64, + /// / Individual requests ids + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, + /// / Batch size (==len(requests)) + #[prost(uint32, tag = "3")] + pub size: u32, + /// / Maximum number of tokens this batch will grow to + #[prost(uint32, tag = "4")] + pub max_tokens: u32, + /// / Number of tokens in the next forward + #[prost(uint32, tag = "5")] + pub current_tokens: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GeneratedText { + /// / Output + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + /// / Number of generated tokens + #[prost(uint32, tag = "2")] + pub generated_tokens: u32, + /// / Finish reason + #[prost(enumeration = "FinishReason", tag = "3")] + pub finish_reason: i32, + /// / Seed + #[prost(uint64, optional, tag = "4")] + pub seed: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Tokens { + /// / Token IDs + #[prost(uint32, repeated, tag = "1")] + pub ids: ::prost::alloc::vec::Vec, + /// / Logprobs + #[prost(float, repeated, tag = "2")] + pub logprobs: ::prost::alloc::vec::Vec, + /// / tokens + #[prost(string, repeated, tag = "3")] + pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// / special + #[prost(bool, repeated, tag = "4")] + pub is_special: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Generation { + /// / Request ID + #[prost(uint64, tag = "1")] + pub request_id: u64, + /// / Prefill tokens (optional) + #[prost(message, optional, tag = "2")] + pub prefill_tokens: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub tokens: ::core::option::Option, + /// / Complete generated text + #[prost(message, optional, tag = "4")] + pub generated_text: ::core::option::Option, + /// / Top tokens + #[prost(message, repeated, tag = "5")] + pub top_tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchRequest { + /// / Batch ID + #[prost(uint64, tag 
= "1")] + pub batch_id: u64, + /// / Requests to keep + #[prost(uint64, repeated, tag = "2")] + pub request_ids: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FilterBatchResponse { + /// / Filtered Batch (cached) + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillRequest { + /// / Batch + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, + /// / Optional cached batch + #[prost(message, optional, tag = "2")] + pub cached_batch: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrefillResponse { + /// / Generation + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeRequest { + /// / Cached batches + #[prost(message, repeated, tag = "1")] + pub batches: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecodeResponse { + /// / Decodes + #[prost(message, repeated, tag = "1")] + pub generations: ::prost::alloc::vec::Vec, + /// / Next batch (cached) + #[prost(message, optional, tag = "2")] + pub batch: ::core::option::Option, + /// / Forward elapsed time in nanoseconds + #[prost(uint64, tag = "3")] + pub forward_ns: u64, + /// / Decode elapsed time in nanoseconds + #[prost(uint64, tag = "4")] + pub decode_ns: u64, + /// / Total elapsed time in nanoseconds + #[prost(uint64, tag = "5")] + pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupRequest { + /// / Batch to warmup on + #[prost(message, optional, tag = "1")] + pub batch: ::core::option::Option, + #[prost(uint32, optional, tag = "2")] + pub max_input_tokens: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub max_prefill_tokens: u32, + #[prost(uint32, optional, tag = "4")] + pub max_total_tokens: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WarmupResponse { + /// / Maximum number of tokens supported by the model + #[prost(uint32, optional, tag = "1")] + pub max_supported_total_tokens: ::core::option::Option, + /// / Maximum input tokens by clients should be equal to request value if it's set + /// / Otherwise warmup automatically allocates a value here + #[prost(uint32, tag = "2")] + pub max_input_tokens: u32, + /// / Maximum total tokens by clients should be equal to request value if it's set + /// / Otherwise warmup automatically allocates a value here + 
#[prost(uint32, tag = "3")] + pub max_total_tokens: u32, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum GrammarType { + None = 0, + Json = 1, + Regex = 2, +} +impl GrammarType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + GrammarType::None => "GRAMMAR_TYPE_NONE", + GrammarType::Json => "GRAMMAR_TYPE_JSON", + GrammarType::Regex => "GRAMMAR_TYPE_REGEX", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GRAMMAR_TYPE_NONE" => Some(Self::None), + "GRAMMAR_TYPE_JSON" => Some(Self::Json), + "GRAMMAR_TYPE_REGEX" => Some(Self::Regex), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FinishReason { + Length = 0, + EosToken = 1, + StopSequence = 2, +} +impl FinishReason { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + FinishReason::Length => "FINISH_REASON_LENGTH", + FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN", + FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FINISH_REASON_LENGTH" => Some(Self::Length), + "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken), + "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence), + _ => None, + } + } +} +/// Generated client implementations. +pub mod text_generation_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::http::Uri; + use tonic::codegen::*; + #[derive(Debug, Clone)] + pub struct TextGenerationServiceClient { + inner: tonic::client::Grpc, + } + impl TextGenerationServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl TextGenerationServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> TextGenerationServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + >>::Error: + Into + Send + Sync, + { + TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. 
+ /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// / Model Info + pub async fn info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Info"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("generate.v3.TextGenerationService", "Info")); + self.inner.unary(req, path, codec).await + } + /// / Service discovery + pub async fn service_discovery( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v3.TextGenerationService/ServiceDiscovery", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "ServiceDiscovery", + )); + self.inner.unary(req, path, codec).await + } + /// / Empties batch cache + pub async fn clear_cache( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v3.TextGenerationService/ClearCache", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "ClearCache", + )); + self.inner.unary(req, path, codec).await + } + /// / Remove requests from a cached batch + pub async fn filter_batch( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/generate.v3.TextGenerationService/FilterBatch", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "FilterBatch", + )); + self.inner.unary(req, path, codec).await + } + /// / Warmup the model 
and compute max cache size + pub async fn warmup( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Warmup"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "Warmup", + )); + self.inner.unary(req, path, codec).await + } + /// / Prefill batch and decode first token + pub async fn prefill( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Prefill"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "Prefill", + )); + self.inner.unary(req, path, codec).await + } + /// / Decode token for a list of prefilled batches + pub async fn decode( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Decode"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "Decode", + )); + self.inner.unary(req, path, codec).await + } + /// / Health check + pub async fn health( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Health"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "generate.v3.TextGenerationService", + "Health", + )); + self.inner.unary(req, path, codec).await + } + } +} diff --git a/backends/client/src/v3/pb/mod.rs b/backends/client/src/v3/pb/mod.rs new file mode 100644 index 00000000..b5397d05 --- /dev/null +++ b/backends/client/src/v3/pb/mod.rs @@ -0,0 +1,6 @@ +// This file is @generated by prost-build. 
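With the v3 definitions above, the request side may leave both limits unset and read the values the shard settled on from the response. A sketch under assumed import paths (they follow the generated module layout shown above; the actual re-exports may differ):

use tonic::transport::Channel;
use text_generation_client::v3::pb::generate::v3::{
    text_generation_service_client::TextGenerationServiceClient, Batch, WarmupRequest,
};

/// Sketch of the new warmup contract: limits may be left unset in the request,
/// in which case the shard picks values and reports them back in the response.
async fn warmup_and_resolve_limits(
    client: &mut TextGenerationServiceClient<Channel>,
    batch: Batch,
    max_prefill_tokens: u32,
) -> Result<(Option<u32>, u32, u32), tonic::Status> {
    let response = client
        .warmup(WarmupRequest {
            batch: Some(batch),
            // `None` asks the shard to derive the limits during warmup.
            max_input_tokens: None,
            max_prefill_tokens,
            max_total_tokens: None,
        })
        .await?
        .into_inner();
    // `max_input_tokens` / `max_total_tokens` echo the request values when they
    // were set, and carry the shard-chosen defaults otherwise.
    Ok((
        response.max_supported_total_tokens,
        response.max_input_tokens,
        response.max_total_tokens,
    ))
}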
+pub mod generate { + pub mod v3 { + include!("generate.v3.rs"); + } +} diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 854a5895..b8a9182c 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -101,11 +101,11 @@ impl ShardedClient { #[instrument(skip(self))] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_length: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let futures: Vec<_> = self .clients .iter_mut() @@ -122,8 +122,10 @@ impl ShardedClient { let results = join_all(futures) .await .into_iter() - .collect::>>>()?; - Ok(results.into_iter().flatten().min()) + .collect::, u32, u32)>>>()?; + let first = results.first().expect("Expect at least 1 warmup result"); + assert!(results.iter().all(|&item| item == *first)); + Ok(*first) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index fe810f24..f4942f64 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -108,20 +108,22 @@ impl Client { #[instrument(skip_all)] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_tokens: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let mut n_tokens = 0; let mut requests = Vec::new(); // Create requests while n_tokens < max_prefill_tokens { - let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut truncate = max_prefill_tokens - n_tokens; + if let Some(max_input_tokens) = max_input_tokens { + truncate = min(max_input_tokens, truncate); + } let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into()); if n_tokens == 0 { input_chunks.push( Chunk::Image(Image { @@ -137,7 +139,7 @@ impl Client { // been updated to support chunks. let mut inputs = String::new(); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + inputs.push_str(&"_test ".to_string().repeat(truncate as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. 
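The sharded client above no longer takes the minimum across shards; every shard must now report the same warmup result, which then becomes the single source of truth for the router's `max_input_tokens` / `max_total_tokens` defaults. A compact sketch of that merge step, mirroring the `assert!` in the hunk:

/// Merge per-shard warmup results: all shards run the same model with the same
/// limits, so any divergence is treated as a bug.
fn merge_warmup_results(results: &[(Option<u32>, u32, u32)]) -> (Option<u32>, u32, u32) {
    let first = *results.first().expect("Expect at least 1 warmup result");
    assert!(results.iter().all(|&item| item == first));
    first
}

fn main() {
    // (max_supported_batch_total_tokens, max_input_tokens, max_total_tokens)
    let merged = merge_warmup_results(&[(Some(65536), 4095, 4096), (Some(65536), 4095, 4096)]);
    println!("{merged:?}");
}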
@@ -146,6 +148,12 @@ impl Client { )); } + let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens { + max_total_tokens - truncate + } else { + 1 + }; + requests.push(Request { id: 0, inputs, @@ -175,7 +183,7 @@ impl Client { grammar_type: GrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: max_total_tokens - truncate, + max_new_tokens, stop_sequences: vec![], ignore_eos_token: true, }), @@ -183,7 +191,7 @@ impl Client { top_n_tokens: 20, adapter_id: None, }); - n_tokens += max_input_length; + n_tokens += truncate; // Check max_batch_size if Some(requests.len()) == max_batch_size { @@ -195,19 +203,23 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: max_input_length, + max_tokens: max_input_tokens.unwrap_or(0), max_blocks: 0, }; let request = tonic::Request::new(WarmupRequest { batch: Some(batch), - max_input_length, + max_input_tokens, max_prefill_tokens, max_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); - Ok(response.max_supported_total_tokens) + Ok(( + response.max_supported_total_tokens, + response.max_input_tokens, + response.max_total_tokens, + )) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index e181cd28..ac916d94 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -102,11 +102,11 @@ impl ShardedClient { #[instrument(skip(self))] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_length: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let futures: Vec<_> = self .clients .iter_mut() @@ -123,8 +123,11 @@ impl ShardedClient { let results = join_all(futures) .await .into_iter() - .collect::>>>()?; - Ok(results.into_iter().flatten().min()) + .collect::, u32, u32)>>>()?; + + let first = results.first().expect("Expect at least 1 warmup result"); + assert!(results.iter().all(|&item| item == *first)); + Ok(*first) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 7daf9eae..09137853 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -37,12 +37,17 @@ pub struct BackendInfo { pub attention_impl: String, #[schema(example = "1")] pub block_size: u32, + + #[schema(example = "30000")] + pub max_input_tokens: usize, + #[schema(example = "32000")] + pub max_total_tokens: usize, } #[allow(clippy::too_many_arguments)] pub async fn connect_backend( - max_input_tokens: usize, - max_total_tokens: usize, + max_input_tokens: Option, + max_total_tokens: Option, master_shard_uds_path: String, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, @@ -51,14 +56,32 @@ pub async fn connect_backend( max_batch_size: Option, ) -> Result<(BackendV3, BackendInfo), V3Error> { // Helper function - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + let check_max_batch_total_tokens = |( + max_supported_batch_total_tokens, + shard_max_input_tokens, + shard_max_total_tokens, + ): (Option, u32, u32)| + -> Result<(u32, usize, usize), V3Error> { + if let Some(max_input_tokens) = max_input_tokens { + assert_eq!(max_input_tokens as u32, shard_max_input_tokens); + } + if let Some(max_total_tokens) = max_total_tokens { + assert_eq!(max_total_tokens as u32, 
shard_max_total_tokens); + } match max_supported_batch_total_tokens { // Older models do not support automatic max-batch-total-tokens None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000 + .max(shard_max_total_tokens) + .max(max_batch_prefill_tokens), + ); tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) + Ok(( + max_batch_total_tokens, + shard_max_input_tokens as usize, + shard_max_total_tokens as usize, + )) } // Flash attention models return their max supported total tokens Some(max_supported_batch_total_tokens) => { @@ -72,11 +95,15 @@ pub async fn connect_backend( "Inferred max batch total tokens: {max_supported_batch_total_tokens}" ); } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(V3Error::NotEnoughMemory(max_total_tokens)); + if shard_max_total_tokens > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize)); } - Ok(max_supported_batch_total_tokens) + Ok(( + max_supported_batch_total_tokens, + shard_max_input_tokens as usize, + shard_max_total_tokens as usize, + )) } } }; @@ -96,23 +123,25 @@ pub async fn connect_backend( // Warmup model tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(V3Error::Warmup)?, - )?; + let answer = sharded_client + .warmup( + max_input_tokens.map(|p| p as u32), + max_batch_prefill_tokens, + max_total_tokens.map(|p| p as u32), + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?; + let (max_batch_total_tokens, max_input_tokens, max_total_tokens) = + check_max_batch_total_tokens(answer)?; tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens); let backend_info = BackendInfo { waiting_served_ratio, max_batch_total_tokens, + max_input_tokens, + max_total_tokens, max_waiting_tokens, max_batch_size, model_device_type: shard_info.device_type.clone(), diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index b4751bd5..bbd09954 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -18,10 +18,10 @@ struct Args { max_stop_sequences: usize, #[clap(default_value = "5", long, env)] max_top_n_tokens: u32, - #[clap(default_value = "1024", long, env)] - max_input_tokens: usize, - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, + #[clap(long, env)] + max_input_tokens: Option, + #[clap(long, env)] + max_total_tokens: Option, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] @@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> { text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); // Validate args - if max_input_tokens >= max_total_tokens { - return Err(RouterError::ArgumentValidation( - "`max_input_tokens` must be < `max_total_tokens`".to_string(), - )); - } - if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), @@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> { // Validate remaining args now that the backend is known let 
support_chunking = backend_info.support_chunking; let max_batch_total_tokens = backend_info.max_batch_total_tokens; + + if max_input_tokens.is_none() { + tracing::info!( + "Maximum input tokens defaulted to {}", + backend_info.max_input_tokens + ); + } + if max_total_tokens.is_none() { + tracing::info!( + "Maximum total tokens defaulted to {}", + backend_info.max_total_tokens + ); + } + + let max_input_tokens = backend_info.max_input_tokens; + let max_total_tokens = backend_info.max_total_tokens; + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d9f569fd..a79467a5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -137,10 +137,7 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> #[derive(Deserialize)] struct RawConfig { - max_position_embeddings: Option, - n_positions: Option, model_type: Option, - max_seq_len: Option, quantization_config: Option, n_embd: Option, hidden_size: Option, @@ -160,7 +157,6 @@ struct VisionConfig {} #[derive(Deserialize)] struct Config { - max_position_embeddings: Option, quantize: Option, head_dim: Option, model_type: Option, @@ -170,10 +166,6 @@ struct Config { impl From for Config { fn from(other: RawConfig) -> Self { - let max_position_embeddings = other - .max_position_embeddings - .or(other.max_seq_len) - .or(other.n_positions); let quantize = other.quantization_config.and_then(|q| q.quant_method); let head_dim = other.head_dim.or_else(|| { match (other.hidden_size, other.n_embd, other.num_attention_heads) { @@ -195,7 +187,6 @@ impl From for Config { let vision_config = other.vision_config; let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); Config { - max_position_embeddings, quantize, head_dim, model_type, @@ -472,7 +463,7 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - /// Default to min(max_position_embeddings - 1, 4095) + /// Default to min(max_allocatable, max_position_embeddings) - 1 #[clap(long, env)] max_input_tokens: Option, @@ -488,7 +479,7 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. - /// Default to min(max_position_embeddings, 4096) + /// Default to min(max_allocatable, max_position_embeddings) #[clap(long, env)] max_total_tokens: Option, @@ -718,9 +709,9 @@ fn shard_manager( cuda_memory_fraction: f32, rope_scaling: Option, rope_factor: Option, - max_total_tokens: usize, + max_total_tokens: Option, max_batch_size: Option, - max_input_tokens: usize, + max_input_tokens: Option, lora_adapters: Option, otlp_endpoint: Option, otlp_service_name: String, @@ -805,8 +796,10 @@ fn shard_manager( shard_args.push(otlp_service_name); // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter. 
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d9f569fd..a79467a5 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -137,10 +137,7 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) ->
 #[derive(Deserialize)]
 struct RawConfig {
-    max_position_embeddings: Option,
-    n_positions: Option,
     model_type: Option,
-    max_seq_len: Option,
     quantization_config: Option,
     n_embd: Option,
     hidden_size: Option,
@@ -160,7 +157,6 @@ struct VisionConfig {}
 #[derive(Deserialize)]
 struct Config {
-    max_position_embeddings: Option,
     quantize: Option,
     head_dim: Option,
     model_type: Option,
@@ -170,10 +166,6 @@ struct Config {
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
-        let max_position_embeddings = other
-            .max_position_embeddings
-            .or(other.max_seq_len)
-            .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
         let head_dim = other.head_dim.or_else(|| {
             match (other.hidden_size, other.n_embd, other.num_attention_heads) {
@@ -195,7 +187,6 @@ impl From<RawConfig> for Config {
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
-            max_position_embeddings,
             quantize,
             head_dim,
             model_type,
@@ -472,7 +463,7 @@ struct Args {
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
     /// Please note that some models have a finite range of sequence they can handle.
-    /// Default to min(max_position_embeddings - 1, 4095)
+    /// Default to min(max_allocatable, max_position_embeddings) - 1
     #[clap(long, env)]
     max_input_tokens: Option<usize>,

@@ -488,7 +479,7 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    /// Default to min(max_position_embeddings, 4096)
+    /// Default to min(max_allocatable, max_position_embeddings)
     #[clap(long, env)]
     max_total_tokens: Option<usize>,

@@ -718,9 +709,9 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option,
     rope_factor: Option,
-    max_total_tokens: usize,
+    max_total_tokens: Option<usize>,
     max_batch_size: Option,
-    max_input_tokens: usize,
+    max_input_tokens: Option<usize>,
     lora_adapters: Option,
     otlp_endpoint: Option,
     otlp_service_name: String,
@@ -805,8 +796,10 @@ fn shard_manager(
     shard_args.push(otlp_service_name);

     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
-    shard_args.push("--max-input-tokens".to_string());
-    shard_args.push(max_input_tokens.to_string());
+    if let Some(max_input_tokens) = max_input_tokens {
+        shard_args.push("--max-input-tokens".to_string());
+        shard_args.push(max_input_tokens.to_string());
+    }

     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -854,10 +847,12 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }

-    envs.push((
-        "MAX_TOTAL_TOKENS".into(),
-        max_total_tokens.to_string().into(),
-    ));
+    if let Some(max_total_tokens) = max_total_tokens {
+        envs.push((
+            "MAX_TOTAL_TOKENS".into(),
+            max_total_tokens.to_string().into(),
+        ));
+    }
     if let Some(max_batch_size) = max_batch_size {
         envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
     }
@@ -1313,8 +1308,8 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec,
-    max_total_tokens: usize,
-    max_input_tokens: usize,
+    max_total_tokens: Option<usize>,
+    max_input_tokens: Option<usize>,
     quantize: Option,
     max_log_level: LevelFilter,
     shutdown: Arc,
@@ -1432,8 +1427,8 @@ fn compute_type(num_shard: usize) -> Option {
 fn spawn_webserver(
     num_shard: usize,
     args: Args,
-    max_input_tokens: usize,
-    max_total_tokens: usize,
+    max_input_tokens: Option<usize>,
+    max_total_tokens: Option<usize>,
     max_batch_prefill_tokens: u32,
     shutdown: Arc,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1452,10 +1447,6 @@ fn spawn_webserver(
         args.max_stop_sequences.to_string(),
         "--max-top-n-tokens".to_string(),
         args.max_top_n_tokens.to_string(),
-        "--max-input-tokens".to_string(),
-        max_input_tokens.to_string(),
-        "--max-total-tokens".to_string(),
-        max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         max_batch_prefill_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
@@ -1473,6 +1464,18 @@ fn spawn_webserver(
         "--tokenizer-name".to_string(),
         args.model_id,
     ];
+    if let Some(max_input_tokens) = max_input_tokens {
+        router_args.extend_from_slice(&[
+            "--max-input-tokens".to_string(),
+            max_input_tokens.to_string(),
+        ]);
+    }
+    if let Some(max_total_tokens) = max_total_tokens {
+        router_args.extend_from_slice(&[
+            "--max-total-tokens".to_string(),
+            max_total_tokens.to_string(),
+        ]);
+    }

     // Pass usage stats flags to router
     router_args.push("--usage-stats".to_string());
@@ -1664,28 +1667,6 @@ fn main() -> Result<(), LauncherError> {
     let config: Option = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
     // Quantization usually means you're even more RAM constrained.
-    let max_default = 4096;
-
-    let max_position_embeddings = if let Some(config) = &config {
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
-                max_default
-            } else {
-                max_position_embeddings
-            }
-        } else {
-            max_default
-        }
-    } else {
-        max_default
-    };

     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
@@ -1698,35 +1679,26 @@ fn main() -> Result<(), LauncherError> {
                     format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
                 )));
             }
-            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
-            (None, None) => {
-                let value = max_position_embeddings - 1;
-                tracing::info!("Default `max_input_tokens` to {value}");
-                value
-            }
-        }
-    };
-    let max_total_tokens = {
-        match args.max_total_tokens {
-            Some(max_total_tokens) => max_total_tokens,
-            None => {
-                let value = max_position_embeddings;
-                tracing::info!("Default `max_total_tokens` to {value}");
-                value
+            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
+                Some(max_input_tokens)
             }
+            (None, None) => None,
         }
     };
+    let max_total_tokens = args.max_total_tokens;
     let max_batch_prefill_tokens = {
         match args.max_batch_prefill_tokens {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
-                    max_batch_size * max_input_tokens
-                } else {
-                    // Adding some edge in order to account for potential block_size alignement
-                    // issue.
-                    max_input_tokens + 50
-                } as u32;
+                // let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
+                //     max_batch_size * max_input_tokens
+                // } else {
+                //     // Adding some edge in order to account for potential block_size alignement
+                //     // issue.
+                //     max_input_tokens + 50
+                // } as u32;
+                // TODO figure out hardware optimal value
+                let value = 4096;
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                 value
             }
@@ -1734,10 +1706,12 @@ fn main() -> Result<(), LauncherError> {
     };

     // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_tokens must be < `max_total_tokens`".to_string(),
-        ));
+    if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
+        if max_input_tokens >= max_total_tokens {
+            return Err(LauncherError::ArgumentValidation(
+                format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
+            ));
+        }
     }

     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
@@ -1792,11 +1766,13 @@ fn main() -> Result<(), LauncherError> {
     }

     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_total_tokens as u32 > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_total_tokens, max_batch_total_tokens
-            )));
+        if let Some(max_total_tokens) = max_total_tokens {
+            if max_total_tokens as u32 > *max_batch_total_tokens {
+                return Err(LauncherError::ArgumentValidation(format!(
+                    "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                    max_total_tokens, max_batch_total_tokens
+                )));
+            }
         }
     }
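The launcher stops reading `max_position_embeddings` (and `max_seq_len`/`n_positions`) from the model config to derive defaults: when the flags are omitted it now forwards nothing and leaves the sizing decision to the shard's warmup, while `max_batch_prefill_tokens` falls back to a flat 4096 until a hardware-aware default is worked out (see the `TODO` above). A minimal Python sketch of the resulting launcher-side defaults, with assumed names:

```python
# Illustrative summary of the launcher defaults after this change; the real
# logic lives in launcher/src/main.rs and the shard fills in missing values.
def launcher_defaults(args: dict):
    max_input_tokens = args.get("max_input_tokens")   # stays None if unset
    max_total_tokens = args.get("max_total_tokens")   # stays None if unset
    # previously derived from max_input_tokens (or max_batch_size * max_input_tokens);
    # now a flat placeholder value.
    max_batch_prefill_tokens = args.get("max_batch_prefill_tokens") or 4096
    if (max_input_tokens is not None and max_total_tokens is not None
            and max_input_tokens >= max_total_tokens):
        raise ValueError("`max_input_tokens` must be < `max_total_tokens`")
    return max_input_tokens, max_total_tokens, max_batch_prefill_tokens
```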
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index c91e7cc4..02980b6f 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -272,12 +272,18 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    uint32 max_input_length = 2;
+    optional uint32 max_input_tokens = 2;
     uint32 max_prefill_tokens = 3;
-    uint32 max_total_tokens = 4;
+    optional uint32 max_total_tokens = 4;
 }

 message WarmupResponse {
     /// Maximum number of tokens supported by the model
     optional uint32 max_supported_total_tokens = 1;
+    /// Maximum input tokens by clients should be equal to request value if it's set
+    /// Otherwise warmup automatically allocates a value here
+    uint32 max_input_tokens = 2;
+    /// Maximum total tokens by clients should be equal to request value if it's set
+    /// Otherwise warmup automatically allocates a value here
+    uint32 max_total_tokens = 3;
 }
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b1270b44..2b2dd940 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -78,6 +78,10 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None


+def small_power_of_2(n: int):
+    return 1 << ((n - 1).bit_length() - 1)
+
+
 def set_sliding_window(sliding_window: int):
     global SLIDING_WINDOW
     SLIDING_WINDOW = sliding_window
@@ -1377,11 +1381,40 @@ class FlashCausalLM(Model):
             self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def warmup(self, batch: FlashCausalLMBatch):
+    def warmup(
+        self,
+        batch: FlashCausalLMBatch,
+        max_input_tokens: Optional[int],
+        max_total_tokens: Optional[int],
+    ):
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()

+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
+        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+        if max_total_tokens is None:
+            model_max_length = self.tokenizer.model_max_length
+            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+            spare_blocks = (
+                # Leave 5% for some wiggle room
+                int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+                + batch.num_blocks
+            )
+            spare_blocks = small_power_of_2(spare_blocks)
+
+            available_blocks = min(model_max_length, spare_blocks)
+            batch.num_blocks = available_blocks
+            batch.max_blocks = available_blocks
+            max_input_tokens = (
+                available_blocks - 1 if max_input_tokens is None else max_input_tokens
+            )
+            max_total_tokens = available_blocks
+
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1393,6 +1426,7 @@ class FlashCausalLM(Model):
             )
             max_bt = batch.max_blocks
             max_s = max_bt * BLOCK_SIZE
+            batch_num_blocks = batch.num_blocks

             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
@@ -1405,14 +1439,7 @@ class FlashCausalLM(Model):

         synchronize(self.device)

-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
-        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-        batch_num_blocks = batch.num_blocks if batch is not None else 0

         num_blocks = (
             # Leave 5% for some wiggle room
@@ -1505,7 +1532,9 @@ class FlashCausalLM(Model):
                 logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
             )

-        return int(num_blocks * BLOCK_SIZE)
+        assert max_input_tokens is not None
+        assert max_total_tokens is not None
+        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

     def tunableop_warmup(self, seqlen: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
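When no total-token budget is supplied, warmup now sizes the limits from free VRAM before allocating the KV cache: it computes the per-block cache footprint, converts the free memory (minus roughly 5% wiggle room) into a block count, rounds that down with `small_power_of_2` (the largest power of two strictly below the input, for inputs of 2 or more), and caps the result at the tokenizer's `model_max_length`. A worked sketch with made-up numbers, mirroring the expressions in the hunk above (constants and inputs are illustrative, not measured values):

```python
# Worked example of the auto-sizing path; numbers are hypothetical.
BLOCK_SIZE = 16          # tokens per KV-cache block (assumed paged layout)
TGI_WIGGLE_ROOM = 0.95   # leave ~5% of free memory unused

def small_power_of_2(n: int) -> int:
    # Largest power of two strictly less than n (e.g. 8 -> 4, 9 -> 8), n >= 2.
    return 1 << ((n - 1).bit_length() - 1)

def auto_total_tokens(free_memory, num_layers, num_kv_heads, head_size,
                      dtype_size, batch_num_blocks, model_max_length):
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # K and V
    spare_blocks = int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + batch_num_blocks
    spare_blocks = small_power_of_2(spare_blocks)
    available_blocks = min(model_max_length, spare_blocks)
    # becomes max_total_tokens; max_input_tokens defaults to this value minus one
    return available_blocks

# e.g. a hypothetical 8 GiB free, 32 layers, 8 KV heads, head_size 128, fp16 (2 bytes):
print(auto_total_tokens(8 * 1024**3, 32, 8, 128, 2, 0, 131072))
```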
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index b3630013..5790de41 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -128,9 +128,11 @@ class Model(ABC):
     ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

-    def warmup(self, batch: B) -> Optional[int]:
+    def warmup(
+        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Tuple[Optional[int], int, int]:
         self.generate_token(batch)
-        return None
+        return None, 0, 0

     def decode_token(
         self,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index aef00fb5..45b48df8 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        max_supported_total_tokens = self.model.warmup(batch)
+
+        # Override default values with None for clearer semantics.
+        max_input_tokens = (
+            request.max_input_tokens if request.HasField("max_input_tokens") else None
+        )
+        max_total_tokens = (
+            request.max_total_tokens if request.HasField("max_total_tokens") else None
+        )
+        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+            self.model.warmup(batch, max_input_tokens, max_total_tokens)
+        )

         return generate_pb2.WarmupResponse(
-            max_supported_total_tokens=max_supported_total_tokens
+            max_supported_total_tokens=max_supported_total_tokens,
+            max_input_tokens=max_input_tokens,
+            max_total_tokens=max_total_tokens,
         )

     async def Prefill(self, request, context):
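End to end, the warmup RPC becomes the negotiation point: the router forwards whatever the operator set (possibly nothing, thanks to the new `optional` fields), and the shard replies with the limits it actually allocated, which the router then adopts. A rough client-side sketch of the contract, assuming the v3 stubs are generated under `text_generation_server.pb` as the server uses them and that the service exposes a `Warmup` RPC; the batch construction and prefill budget below are placeholders:

```python
# Hypothetical client-side view of the new warmup contract.
from typing import Optional, Tuple
from text_generation_server.pb import generate_pb2, generate_pb2_grpc

def negotiate_limits(
    stub: generate_pb2_grpc.TextGenerationServiceStub,
    batch: generate_pb2.Batch,
    max_input_tokens: Optional[int] = None,
    max_total_tokens: Optional[int] = None,
) -> Tuple[Optional[int], int, int]:
    request = generate_pb2.WarmupRequest(batch=batch, max_prefill_tokens=4096)
    # proto3 optional fields: only set them when the operator provided a value,
    # so the shard knows it is free to auto-allocate.
    if max_input_tokens is not None:
        request.max_input_tokens = max_input_tokens
    if max_total_tokens is not None:
        request.max_total_tokens = max_total_tokens
    response = stub.Warmup(request)
    # The response always carries concrete values, either echoed back or inferred.
    return (
        response.max_supported_total_tokens,
        response.max_input_tokens,
        response.max_total_tokens,
    )
```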