Choosing input/total tokens automatically based on available VRAM?
This commit is contained in:
parent
7f54b7336a
commit
a1aac7843b
|
@ -0,0 +1,613 @@
|
||||||
|
// This file is @generated by prost-build.
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct HealthRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct HealthResponse {}
|
||||||
|
/// / Empty request
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct InfoRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct InfoResponse {
|
||||||
|
#[prost(bool, tag = "1")]
|
||||||
|
pub requires_padding: bool,
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub dtype: ::prost::alloc::string::String,
|
||||||
|
#[prost(string, tag = "3")]
|
||||||
|
pub device_type: ::prost::alloc::string::String,
|
||||||
|
#[prost(uint32, optional, tag = "4")]
|
||||||
|
pub window_size: ::core::option::Option<u32>,
|
||||||
|
#[prost(uint32, tag = "5")]
|
||||||
|
pub speculate: u32,
|
||||||
|
}
|
||||||
|
/// / Empty request
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ServiceDiscoveryRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ServiceDiscoveryResponse {
|
||||||
|
/// / Other shards urls
|
||||||
|
#[prost(string, repeated, tag = "1")]
|
||||||
|
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ClearCacheRequest {
|
||||||
|
/// / Optional batch id
|
||||||
|
#[prost(uint64, optional, tag = "1")]
|
||||||
|
pub id: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
/// / Empty response
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ClearCacheResponse {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct NextTokenChooserParameters {
|
||||||
|
/// / exponential scaling output probability distribution
|
||||||
|
#[prost(float, tag = "1")]
|
||||||
|
pub temperature: f32,
|
||||||
|
/// / restricting to the k highest probability elements
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub top_k: u32,
|
||||||
|
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||||
|
#[prost(float, tag = "3")]
|
||||||
|
pub top_p: f32,
|
||||||
|
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||||
|
#[prost(float, tag = "4")]
|
||||||
|
pub typical_p: f32,
|
||||||
|
/// / apply sampling on the logits
|
||||||
|
#[prost(bool, tag = "5")]
|
||||||
|
pub do_sample: bool,
|
||||||
|
/// / random seed for sampling
|
||||||
|
#[prost(uint64, tag = "6")]
|
||||||
|
pub seed: u64,
|
||||||
|
/// / repetition penalty
|
||||||
|
#[prost(float, tag = "7")]
|
||||||
|
pub repetition_penalty: f32,
|
||||||
|
/// / frequency penalty
|
||||||
|
#[prost(float, tag = "9")]
|
||||||
|
pub frequency_penalty: f32,
|
||||||
|
/// / token watermarking using "A Watermark for Large Language Models"
|
||||||
|
#[prost(bool, tag = "8")]
|
||||||
|
pub watermark: bool,
|
||||||
|
/// / grammar (applied if not empty)
|
||||||
|
#[prost(string, tag = "10")]
|
||||||
|
pub grammar: ::prost::alloc::string::String,
|
||||||
|
/// / grammar type
|
||||||
|
#[prost(enumeration = "GrammarType", tag = "11")]
|
||||||
|
pub grammar_type: i32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct StoppingCriteriaParameters {
|
||||||
|
/// / Maximum number of generated tokens
|
||||||
|
#[prost(uint32, tag = "1")]
|
||||||
|
pub max_new_tokens: u32,
|
||||||
|
/// / Optional stopping sequences
|
||||||
|
#[prost(string, repeated, tag = "2")]
|
||||||
|
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
/// / Ignore end of sequence token
|
||||||
|
/// / used for benchmarking
|
||||||
|
#[prost(bool, tag = "3")]
|
||||||
|
pub ignore_eos_token: bool,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Request {
|
||||||
|
/// / Request ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / The generation context
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub inputs: ::prost::alloc::string::String,
|
||||||
|
/// / Context truncation
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub truncate: u32,
|
||||||
|
/// / Next Token Chooser Parameters
|
||||||
|
#[prost(message, optional, tag = "4")]
|
||||||
|
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
|
||||||
|
/// / Stopping Criteria Parameters
|
||||||
|
#[prost(message, optional, tag = "5")]
|
||||||
|
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
|
||||||
|
/// / Return prefill logprobs
|
||||||
|
#[prost(bool, tag = "6")]
|
||||||
|
pub prefill_logprobs: bool,
|
||||||
|
/// / Return most likely n tokens
|
||||||
|
#[prost(uint32, tag = "7")]
|
||||||
|
pub top_n_tokens: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Batch {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / Individual requests
|
||||||
|
#[prost(message, repeated, tag = "2")]
|
||||||
|
pub requests: ::prost::alloc::vec::Vec<Request>,
|
||||||
|
/// / Batch size (==len(requests))
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub size: u32,
|
||||||
|
/// / Maximum number of tokens this batch will grow to
|
||||||
|
#[prost(uint32, tag = "4")]
|
||||||
|
pub max_tokens: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct CachedBatch {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / Individual requests ids
|
||||||
|
#[prost(uint64, repeated, tag = "2")]
|
||||||
|
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||||
|
/// / Batch size (==len(requests))
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub size: u32,
|
||||||
|
/// / Maximum number of tokens this batch will grow to
|
||||||
|
#[prost(uint32, tag = "4")]
|
||||||
|
pub max_tokens: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct GeneratedText {
|
||||||
|
/// / Output
|
||||||
|
#[prost(string, tag = "1")]
|
||||||
|
pub text: ::prost::alloc::string::String,
|
||||||
|
/// / Number of generated tokens
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub generated_tokens: u32,
|
||||||
|
/// / Finish reason
|
||||||
|
#[prost(enumeration = "FinishReason", tag = "3")]
|
||||||
|
pub finish_reason: i32,
|
||||||
|
/// / Seed
|
||||||
|
#[prost(uint64, optional, tag = "4")]
|
||||||
|
pub seed: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Tokens {
|
||||||
|
/// / Token IDs
|
||||||
|
#[prost(uint32, repeated, tag = "1")]
|
||||||
|
pub ids: ::prost::alloc::vec::Vec<u32>,
|
||||||
|
/// / Logprobs
|
||||||
|
#[prost(float, repeated, tag = "2")]
|
||||||
|
pub logprobs: ::prost::alloc::vec::Vec<f32>,
|
||||||
|
/// / tokens
|
||||||
|
#[prost(string, repeated, tag = "3")]
|
||||||
|
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
/// / special
|
||||||
|
#[prost(bool, repeated, tag = "4")]
|
||||||
|
pub is_special: ::prost::alloc::vec::Vec<bool>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Generation {
|
||||||
|
/// / Request ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub request_id: u64,
|
||||||
|
/// / Prefill tokens (optional)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub prefill_tokens: ::core::option::Option<Tokens>,
|
||||||
|
#[prost(message, optional, tag = "3")]
|
||||||
|
pub tokens: ::core::option::Option<Tokens>,
|
||||||
|
/// / Complete generated text
|
||||||
|
#[prost(message, optional, tag = "4")]
|
||||||
|
pub generated_text: ::core::option::Option<GeneratedText>,
|
||||||
|
/// / Top tokens
|
||||||
|
#[prost(message, repeated, tag = "5")]
|
||||||
|
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct FilterBatchRequest {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub batch_id: u64,
|
||||||
|
/// / Requests to keep
|
||||||
|
#[prost(uint64, repeated, tag = "2")]
|
||||||
|
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct FilterBatchResponse {
|
||||||
|
/// / Filtered Batch (cached)
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct PrefillRequest {
|
||||||
|
/// / Batch
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<Batch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct PrefillResponse {
|
||||||
|
/// / Generation
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||||
|
/// / Next batch (cached)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
/// / Forward elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "3")]
|
||||||
|
pub forward_ns: u64,
|
||||||
|
/// / Decode elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "4")]
|
||||||
|
pub decode_ns: u64,
|
||||||
|
/// / Total elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "5")]
|
||||||
|
pub total_ns: u64,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct DecodeRequest {
|
||||||
|
/// / Cached batches
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct DecodeResponse {
|
||||||
|
/// / Decodes
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||||
|
/// / Next batch (cached)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
/// / Forward elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "3")]
|
||||||
|
pub forward_ns: u64,
|
||||||
|
/// / Decode elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "4")]
|
||||||
|
pub decode_ns: u64,
|
||||||
|
/// / Total elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "5")]
|
||||||
|
pub total_ns: u64,
|
||||||
|
/// / Concatenate elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, optional, tag = "6")]
|
||||||
|
pub concat_ns: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct WarmupRequest {
|
||||||
|
/// / Batch to warmup on
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<Batch>,
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub max_input_length: u32,
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub max_prefill_tokens: u32,
|
||||||
|
#[prost(uint32, tag = "4")]
|
||||||
|
pub max_total_tokens: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct WarmupResponse {
|
||||||
|
/// / Maximum number of tokens supported by the model
|
||||||
|
#[prost(uint32, optional, tag = "1")]
|
||||||
|
pub max_supported_total_tokens: ::core::option::Option<u32>,
|
||||||
|
}
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum GrammarType {
|
||||||
|
None = 0,
|
||||||
|
Json = 1,
|
||||||
|
Regex = 2,
|
||||||
|
}
|
||||||
|
impl GrammarType {
|
||||||
|
/// String value of the enum field names used in the ProtoBuf definition.
|
||||||
|
///
|
||||||
|
/// The values are not transformed in any way and thus are considered stable
|
||||||
|
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||||
|
pub fn as_str_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
GrammarType::None => "GRAMMAR_TYPE_NONE",
|
||||||
|
GrammarType::Json => "GRAMMAR_TYPE_JSON",
|
||||||
|
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||||
|
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||||
|
match value {
|
||||||
|
"GRAMMAR_TYPE_NONE" => Some(Self::None),
|
||||||
|
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
|
||||||
|
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum FinishReason {
|
||||||
|
Length = 0,
|
||||||
|
EosToken = 1,
|
||||||
|
StopSequence = 2,
|
||||||
|
}
|
||||||
|
impl FinishReason {
|
||||||
|
/// String value of the enum field names used in the ProtoBuf definition.
|
||||||
|
///
|
||||||
|
/// The values are not transformed in any way and thus are considered stable
|
||||||
|
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||||
|
pub fn as_str_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
FinishReason::Length => "FINISH_REASON_LENGTH",
|
||||||
|
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
|
||||||
|
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||||
|
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||||
|
match value {
|
||||||
|
"FINISH_REASON_LENGTH" => Some(Self::Length),
|
||||||
|
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
|
||||||
|
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Generated client implementations.
|
||||||
|
pub mod text_generation_service_client {
|
||||||
|
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
|
||||||
|
use tonic::codegen::http::Uri;
|
||||||
|
use tonic::codegen::*;
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TextGenerationServiceClient<T> {
|
||||||
|
inner: tonic::client::Grpc<T>,
|
||||||
|
}
|
||||||
|
impl TextGenerationServiceClient<tonic::transport::Channel> {
|
||||||
|
/// Attempt to create a new client by connecting to a given endpoint.
|
||||||
|
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
|
||||||
|
where
|
||||||
|
D: TryInto<tonic::transport::Endpoint>,
|
||||||
|
D::Error: Into<StdError>,
|
||||||
|
{
|
||||||
|
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
|
||||||
|
Ok(Self::new(conn))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<T> TextGenerationServiceClient<T>
|
||||||
|
where
|
||||||
|
T: tonic::client::GrpcService<tonic::body::BoxBody>,
|
||||||
|
T::Error: Into<StdError>,
|
||||||
|
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
|
||||||
|
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
|
||||||
|
{
|
||||||
|
pub fn new(inner: T) -> Self {
|
||||||
|
let inner = tonic::client::Grpc::new(inner);
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
pub fn with_origin(inner: T, origin: Uri) -> Self {
|
||||||
|
let inner = tonic::client::Grpc::with_origin(inner, origin);
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
pub fn with_interceptor<F>(
|
||||||
|
inner: T,
|
||||||
|
interceptor: F,
|
||||||
|
) -> TextGenerationServiceClient<InterceptedService<T, F>>
|
||||||
|
where
|
||||||
|
F: tonic::service::Interceptor,
|
||||||
|
T::ResponseBody: Default,
|
||||||
|
T: tonic::codegen::Service<
|
||||||
|
http::Request<tonic::body::BoxBody>,
|
||||||
|
Response = http::Response<
|
||||||
|
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
<T as tonic::codegen::Service<http::Request<tonic::body::BoxBody>>>::Error:
|
||||||
|
Into<StdError> + Send + Sync,
|
||||||
|
{
|
||||||
|
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
|
||||||
|
}
|
||||||
|
/// Compress requests with the given encoding.
|
||||||
|
///
|
||||||
|
/// This requires the server to support it otherwise it might respond with an
|
||||||
|
/// error.
|
||||||
|
#[must_use]
|
||||||
|
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||||
|
self.inner = self.inner.send_compressed(encoding);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Enable decompressing responses.
|
||||||
|
#[must_use]
|
||||||
|
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||||
|
self.inner = self.inner.accept_compressed(encoding);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Limits the maximum size of a decoded message.
|
||||||
|
///
|
||||||
|
/// Default: `4MB`
|
||||||
|
#[must_use]
|
||||||
|
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
|
||||||
|
self.inner = self.inner.max_decoding_message_size(limit);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Limits the maximum size of an encoded message.
|
||||||
|
///
|
||||||
|
/// Default: `usize::MAX`
|
||||||
|
#[must_use]
|
||||||
|
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
|
||||||
|
self.inner = self.inner.max_encoding_message_size(limit);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// / Model Info
|
||||||
|
pub async fn info(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::InfoRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Info");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut()
|
||||||
|
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Service discovery
|
||||||
|
pub async fn service_discovery(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::ServiceDiscoveryResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v2.TextGenerationService/ServiceDiscovery",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"ServiceDiscovery",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Empties batch cache
|
||||||
|
pub async fn clear_cache(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::ClearCacheResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v2.TextGenerationService/ClearCache",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"ClearCache",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Remove requests from a cached batch
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::FilterBatchResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v2.TextGenerationService/FilterBatch",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"FilterBatch",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Warmup the model and compute max cache size
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::WarmupRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Warmup");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"Warmup",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Prefill batch and decode first token
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::PrefillRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::PrefillResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Prefill");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"Prefill",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Decode token for a list of prefilled batches
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::DecodeRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Decode");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"Decode",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Health check
|
||||||
|
pub async fn health(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::HealthRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v2.TextGenerationService/Health");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v2.TextGenerationService",
|
||||||
|
"Health",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
// This file is @generated by prost-build.
|
||||||
|
pub mod generate {
|
||||||
|
pub mod v2 {
|
||||||
|
include!("generate.v2.rs");
|
||||||
|
}
|
||||||
|
}
|
|
@ -107,20 +107,22 @@ impl Client {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_tokens: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
// Create requests
|
// Create requests
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let mut truncate = max_prefill_tokens - n_tokens;
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
truncate = min(max_input_tokens, truncate);
|
||||||
|
}
|
||||||
|
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
input_chunks
|
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
|
||||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
input_chunks.push(
|
input_chunks.push(
|
||||||
Chunk::Image(Image {
|
Chunk::Image(Image {
|
||||||
|
@ -136,7 +138,7 @@ impl Client {
|
||||||
// been updated to support chunks.
|
// been updated to support chunks.
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
// 1 request is enough to test vision heads.
|
// 1 request is enough to test vision heads.
|
||||||
// Sending images on other queries messes up easily with truncation.
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
@ -145,6 +147,12 @@ impl Client {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
max_total_tokens - truncate
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
inputs,
|
inputs,
|
||||||
|
@ -175,7 +183,7 @@ impl Client {
|
||||||
grammar_type: GrammarType::None as i32,
|
grammar_type: GrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: max_total_tokens - truncate,
|
max_new_tokens,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
ignore_eos_token: true,
|
ignore_eos_token: true,
|
||||||
}),
|
}),
|
||||||
|
@ -183,7 +191,7 @@ impl Client {
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += truncate;
|
||||||
|
|
||||||
// Check max_batch_size
|
// Check max_batch_size
|
||||||
if Some(requests.len()) == max_batch_size {
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
@ -195,19 +203,23 @@ impl Client {
|
||||||
id: 0,
|
id: 0,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_input_length,
|
max_tokens: max_input_tokens.unwrap_or(0),
|
||||||
max_blocks: 0,
|
max_blocks: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = tonic::Request::new(WarmupRequest {
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
batch: Some(batch),
|
batch: Some(batch),
|
||||||
max_input_length,
|
max_input_tokens,
|
||||||
max_prefill_tokens,
|
max_prefill_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
})
|
})
|
||||||
.inject_context();
|
.inject_context();
|
||||||
let response = self.stub.warmup(request).await?.into_inner();
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
Ok(response.max_supported_total_tokens)
|
Ok((
|
||||||
|
response.max_supported_total_tokens,
|
||||||
|
response.max_input_tokens,
|
||||||
|
response.max_total_tokens,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -0,0 +1,699 @@
|
||||||
|
// This file is @generated by prost-build.
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct HealthRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct HealthResponse {}
|
||||||
|
/// / Empty request
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct InfoRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct InfoResponse {
|
||||||
|
#[prost(bool, tag = "1")]
|
||||||
|
pub requires_padding: bool,
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub dtype: ::prost::alloc::string::String,
|
||||||
|
#[prost(string, tag = "3")]
|
||||||
|
pub device_type: ::prost::alloc::string::String,
|
||||||
|
#[prost(uint32, optional, tag = "4")]
|
||||||
|
pub window_size: ::core::option::Option<u32>,
|
||||||
|
#[prost(uint32, tag = "5")]
|
||||||
|
pub speculate: u32,
|
||||||
|
#[prost(bool, tag = "6")]
|
||||||
|
pub support_chunking: bool,
|
||||||
|
#[prost(bool, tag = "7")]
|
||||||
|
pub use_prefix_caching: bool,
|
||||||
|
#[prost(string, tag = "8")]
|
||||||
|
pub attention_impl: ::prost::alloc::string::String,
|
||||||
|
#[prost(uint32, tag = "9")]
|
||||||
|
pub block_size: u32,
|
||||||
|
}
|
||||||
|
/// / Empty request
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ServiceDiscoveryRequest {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ServiceDiscoveryResponse {
|
||||||
|
/// / Other shards urls
|
||||||
|
#[prost(string, repeated, tag = "1")]
|
||||||
|
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ClearCacheRequest {
|
||||||
|
/// / Optional batch id
|
||||||
|
#[prost(uint64, optional, tag = "1")]
|
||||||
|
pub id: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
/// / Empty response
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct ClearCacheResponse {}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Image {
|
||||||
|
/// / Binary image data.
|
||||||
|
#[prost(bytes = "vec", tag = "1")]
|
||||||
|
pub data: ::prost::alloc::vec::Vec<u8>,
|
||||||
|
/// / Image MIME type.
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub mimetype: ::prost::alloc::string::String,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct InputChunk {
|
||||||
|
#[prost(oneof = "input_chunk::Chunk", tags = "1, 2")]
|
||||||
|
pub chunk: ::core::option::Option<input_chunk::Chunk>,
|
||||||
|
}
|
||||||
|
/// Nested message and enum types in `InputChunk`.
|
||||||
|
pub mod input_chunk {
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Oneof)]
|
||||||
|
pub enum Chunk {
|
||||||
|
/// / Plain text data
|
||||||
|
#[prost(string, tag = "1")]
|
||||||
|
Text(::prost::alloc::string::String),
|
||||||
|
/// / Image data
|
||||||
|
#[prost(message, tag = "2")]
|
||||||
|
Image(super::Image),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Input {
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub chunks: ::prost::alloc::vec::Vec<InputChunk>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct NextTokenChooserParameters {
|
||||||
|
/// / exponential scaling output probability distribution
|
||||||
|
#[prost(float, tag = "1")]
|
||||||
|
pub temperature: f32,
|
||||||
|
/// / restricting to the k highest probability elements
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub top_k: u32,
|
||||||
|
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||||
|
#[prost(float, tag = "3")]
|
||||||
|
pub top_p: f32,
|
||||||
|
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||||
|
#[prost(float, tag = "4")]
|
||||||
|
pub typical_p: f32,
|
||||||
|
/// / apply sampling on the logits
|
||||||
|
#[prost(bool, tag = "5")]
|
||||||
|
pub do_sample: bool,
|
||||||
|
/// / random seed for sampling
|
||||||
|
#[prost(uint64, tag = "6")]
|
||||||
|
pub seed: u64,
|
||||||
|
/// / repetition penalty
|
||||||
|
#[prost(float, tag = "7")]
|
||||||
|
pub repetition_penalty: f32,
|
||||||
|
/// / frequency penalty
|
||||||
|
#[prost(float, tag = "9")]
|
||||||
|
pub frequency_penalty: f32,
|
||||||
|
/// / token watermarking using "A Watermark for Large Language Models"
|
||||||
|
#[prost(bool, tag = "8")]
|
||||||
|
pub watermark: bool,
|
||||||
|
/// / grammar (applied if not empty)
|
||||||
|
#[prost(string, tag = "10")]
|
||||||
|
pub grammar: ::prost::alloc::string::String,
|
||||||
|
/// / grammar type
|
||||||
|
#[prost(enumeration = "GrammarType", tag = "11")]
|
||||||
|
pub grammar_type: i32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct StoppingCriteriaParameters {
|
||||||
|
/// / Maximum number of generated tokens
|
||||||
|
#[prost(uint32, tag = "1")]
|
||||||
|
pub max_new_tokens: u32,
|
||||||
|
/// / Optional stopping sequences
|
||||||
|
#[prost(string, repeated, tag = "2")]
|
||||||
|
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
/// / Ignore end of sequence token
|
||||||
|
/// / used for benchmarking
|
||||||
|
#[prost(bool, tag = "3")]
|
||||||
|
pub ignore_eos_token: bool,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Request {
|
||||||
|
/// / Request ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / The generation context as chunks
|
||||||
|
#[prost(message, optional, tag = "8")]
|
||||||
|
pub input_chunks: ::core::option::Option<Input>,
|
||||||
|
/// / The generation context, stringified input_chunks
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub inputs: ::prost::alloc::string::String,
|
||||||
|
/// / Context truncation
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub truncate: u32,
|
||||||
|
/// / Next Token Chooser Parameters
|
||||||
|
#[prost(message, optional, tag = "4")]
|
||||||
|
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
|
||||||
|
/// / Stopping Criteria Parameters
|
||||||
|
#[prost(message, optional, tag = "5")]
|
||||||
|
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
|
||||||
|
/// / Return prefill logprobs
|
||||||
|
#[prost(bool, tag = "6")]
|
||||||
|
pub prefill_logprobs: bool,
|
||||||
|
/// / Return most likely n tokens
|
||||||
|
#[prost(uint32, tag = "7")]
|
||||||
|
pub top_n_tokens: u32,
|
||||||
|
/// / Paged attention blocks
|
||||||
|
#[prost(uint32, repeated, tag = "9")]
|
||||||
|
pub blocks: ::prost::alloc::vec::Vec<u32>,
|
||||||
|
/// / Paged attention slots
|
||||||
|
#[prost(uint32, repeated, tag = "10")]
|
||||||
|
pub slots: ::prost::alloc::vec::Vec<u32>,
|
||||||
|
/// / LORA adapter index
|
||||||
|
#[prost(string, optional, tag = "11")]
|
||||||
|
pub adapter_id: ::core::option::Option<::prost::alloc::string::String>,
|
||||||
|
/// / Tokens that can be retrieved from the KV cache.
|
||||||
|
/// / This value is set for the first prefill and never reset
|
||||||
|
#[prost(uint32, tag = "12")]
|
||||||
|
pub cache_len: u32,
|
||||||
|
/// / Context truncation
|
||||||
|
#[prost(bool, tag = "13")]
|
||||||
|
pub add_special_tokens: bool,
|
||||||
|
/// / Chunk of tokens that must be computed for the first prefill
|
||||||
|
/// / This value is set for the first prefill and never reset
|
||||||
|
#[prost(uint32, optional, tag = "14")]
|
||||||
|
pub chunk_len: ::core::option::Option<u32>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Batch {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / Individual requests
|
||||||
|
#[prost(message, repeated, tag = "2")]
|
||||||
|
pub requests: ::prost::alloc::vec::Vec<Request>,
|
||||||
|
/// / Batch size (==len(requests))
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub size: u32,
|
||||||
|
/// / Maximum number of tokens this batch will grow to
|
||||||
|
#[prost(uint32, tag = "4")]
|
||||||
|
pub max_tokens: u32,
|
||||||
|
/// / Maximum number of Paged Attention blocks
|
||||||
|
#[prost(uint32, tag = "5")]
|
||||||
|
pub max_blocks: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct CachedBatch {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub id: u64,
|
||||||
|
/// / Individual requests ids
|
||||||
|
#[prost(uint64, repeated, tag = "2")]
|
||||||
|
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||||
|
/// / Batch size (==len(requests))
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub size: u32,
|
||||||
|
/// / Maximum number of tokens this batch will grow to
|
||||||
|
#[prost(uint32, tag = "4")]
|
||||||
|
pub max_tokens: u32,
|
||||||
|
/// / Number of tokens in the next forward
|
||||||
|
#[prost(uint32, tag = "5")]
|
||||||
|
pub current_tokens: u32,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct GeneratedText {
|
||||||
|
/// / Output
|
||||||
|
#[prost(string, tag = "1")]
|
||||||
|
pub text: ::prost::alloc::string::String,
|
||||||
|
/// / Number of generated tokens
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub generated_tokens: u32,
|
||||||
|
/// / Finish reason
|
||||||
|
#[prost(enumeration = "FinishReason", tag = "3")]
|
||||||
|
pub finish_reason: i32,
|
||||||
|
/// / Seed
|
||||||
|
#[prost(uint64, optional, tag = "4")]
|
||||||
|
pub seed: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Tokens {
|
||||||
|
/// / Token IDs
|
||||||
|
#[prost(uint32, repeated, tag = "1")]
|
||||||
|
pub ids: ::prost::alloc::vec::Vec<u32>,
|
||||||
|
/// / Logprobs
|
||||||
|
#[prost(float, repeated, tag = "2")]
|
||||||
|
pub logprobs: ::prost::alloc::vec::Vec<f32>,
|
||||||
|
/// / tokens
|
||||||
|
#[prost(string, repeated, tag = "3")]
|
||||||
|
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||||
|
/// / special
|
||||||
|
#[prost(bool, repeated, tag = "4")]
|
||||||
|
pub is_special: ::prost::alloc::vec::Vec<bool>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct Generation {
|
||||||
|
/// / Request ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub request_id: u64,
|
||||||
|
/// / Prefill tokens (optional)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub prefill_tokens: ::core::option::Option<Tokens>,
|
||||||
|
#[prost(message, optional, tag = "3")]
|
||||||
|
pub tokens: ::core::option::Option<Tokens>,
|
||||||
|
/// / Complete generated text
|
||||||
|
#[prost(message, optional, tag = "4")]
|
||||||
|
pub generated_text: ::core::option::Option<GeneratedText>,
|
||||||
|
/// / Top tokens
|
||||||
|
#[prost(message, repeated, tag = "5")]
|
||||||
|
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct FilterBatchRequest {
|
||||||
|
/// / Batch ID
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub batch_id: u64,
|
||||||
|
/// / Requests to keep
|
||||||
|
#[prost(uint64, repeated, tag = "2")]
|
||||||
|
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct FilterBatchResponse {
|
||||||
|
/// / Filtered Batch (cached)
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct PrefillRequest {
|
||||||
|
/// / Batch
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<Batch>,
|
||||||
|
/// / Optional cached batch
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub cached_batch: ::core::option::Option<CachedBatch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct PrefillResponse {
|
||||||
|
/// / Generation
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||||
|
/// / Next batch (cached)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
/// / Forward elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "3")]
|
||||||
|
pub forward_ns: u64,
|
||||||
|
/// / Decode elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "4")]
|
||||||
|
pub decode_ns: u64,
|
||||||
|
/// / Total elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "5")]
|
||||||
|
pub total_ns: u64,
|
||||||
|
/// / Concatenate elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, optional, tag = "6")]
|
||||||
|
pub concat_ns: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct DecodeRequest {
|
||||||
|
/// / Cached batches
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct DecodeResponse {
|
||||||
|
/// / Decodes
|
||||||
|
#[prost(message, repeated, tag = "1")]
|
||||||
|
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||||
|
/// / Next batch (cached)
|
||||||
|
#[prost(message, optional, tag = "2")]
|
||||||
|
pub batch: ::core::option::Option<CachedBatch>,
|
||||||
|
/// / Forward elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "3")]
|
||||||
|
pub forward_ns: u64,
|
||||||
|
/// / Decode elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "4")]
|
||||||
|
pub decode_ns: u64,
|
||||||
|
/// / Total elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, tag = "5")]
|
||||||
|
pub total_ns: u64,
|
||||||
|
/// / Concatenate elapsed time in nanoseconds
|
||||||
|
#[prost(uint64, optional, tag = "6")]
|
||||||
|
pub concat_ns: ::core::option::Option<u64>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct WarmupRequest {
|
||||||
|
/// / Batch to warmup on
|
||||||
|
#[prost(message, optional, tag = "1")]
|
||||||
|
pub batch: ::core::option::Option<Batch>,
|
||||||
|
#[prost(uint32, optional, tag = "2")]
|
||||||
|
pub max_input_tokens: ::core::option::Option<u32>,
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub max_prefill_tokens: u32,
|
||||||
|
#[prost(uint32, optional, tag = "4")]
|
||||||
|
pub max_total_tokens: ::core::option::Option<u32>,
|
||||||
|
}
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||||
|
pub struct WarmupResponse {
|
||||||
|
/// / Maximum number of tokens supported by the model
|
||||||
|
#[prost(uint32, optional, tag = "1")]
|
||||||
|
pub max_supported_total_tokens: ::core::option::Option<u32>,
|
||||||
|
/// / Maximum input tokens by clients should be equal to request value if it's set
|
||||||
|
/// / Otherwise warmup automatically allocates a value here
|
||||||
|
#[prost(uint32, tag = "2")]
|
||||||
|
pub max_input_tokens: u32,
|
||||||
|
/// / Maximum total tokens by clients should be equal to request value if it's set
|
||||||
|
/// / Otherwise warmup automatically allocates a value here
|
||||||
|
#[prost(uint32, tag = "3")]
|
||||||
|
pub max_total_tokens: u32,
|
||||||
|
}
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum GrammarType {
|
||||||
|
None = 0,
|
||||||
|
Json = 1,
|
||||||
|
Regex = 2,
|
||||||
|
}
|
||||||
|
impl GrammarType {
|
||||||
|
/// String value of the enum field names used in the ProtoBuf definition.
|
||||||
|
///
|
||||||
|
/// The values are not transformed in any way and thus are considered stable
|
||||||
|
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||||
|
pub fn as_str_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
GrammarType::None => "GRAMMAR_TYPE_NONE",
|
||||||
|
GrammarType::Json => "GRAMMAR_TYPE_JSON",
|
||||||
|
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||||
|
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||||
|
match value {
|
||||||
|
"GRAMMAR_TYPE_NONE" => Some(Self::None),
|
||||||
|
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
|
||||||
|
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum FinishReason {
|
||||||
|
Length = 0,
|
||||||
|
EosToken = 1,
|
||||||
|
StopSequence = 2,
|
||||||
|
}
|
||||||
|
impl FinishReason {
|
||||||
|
/// String value of the enum field names used in the ProtoBuf definition.
|
||||||
|
///
|
||||||
|
/// The values are not transformed in any way and thus are considered stable
|
||||||
|
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||||
|
pub fn as_str_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
FinishReason::Length => "FINISH_REASON_LENGTH",
|
||||||
|
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
|
||||||
|
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||||
|
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||||
|
match value {
|
||||||
|
"FINISH_REASON_LENGTH" => Some(Self::Length),
|
||||||
|
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
|
||||||
|
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Generated client implementations.
|
||||||
|
pub mod text_generation_service_client {
|
||||||
|
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
|
||||||
|
use tonic::codegen::http::Uri;
|
||||||
|
use tonic::codegen::*;
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TextGenerationServiceClient<T> {
|
||||||
|
inner: tonic::client::Grpc<T>,
|
||||||
|
}
|
||||||
|
impl TextGenerationServiceClient<tonic::transport::Channel> {
|
||||||
|
/// Attempt to create a new client by connecting to a given endpoint.
|
||||||
|
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
|
||||||
|
where
|
||||||
|
D: TryInto<tonic::transport::Endpoint>,
|
||||||
|
D::Error: Into<StdError>,
|
||||||
|
{
|
||||||
|
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
|
||||||
|
Ok(Self::new(conn))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<T> TextGenerationServiceClient<T>
|
||||||
|
where
|
||||||
|
T: tonic::client::GrpcService<tonic::body::BoxBody>,
|
||||||
|
T::Error: Into<StdError>,
|
||||||
|
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
|
||||||
|
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
|
||||||
|
{
|
||||||
|
pub fn new(inner: T) -> Self {
|
||||||
|
let inner = tonic::client::Grpc::new(inner);
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
pub fn with_origin(inner: T, origin: Uri) -> Self {
|
||||||
|
let inner = tonic::client::Grpc::with_origin(inner, origin);
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
pub fn with_interceptor<F>(
|
||||||
|
inner: T,
|
||||||
|
interceptor: F,
|
||||||
|
) -> TextGenerationServiceClient<InterceptedService<T, F>>
|
||||||
|
where
|
||||||
|
F: tonic::service::Interceptor,
|
||||||
|
T::ResponseBody: Default,
|
||||||
|
T: tonic::codegen::Service<
|
||||||
|
http::Request<tonic::body::BoxBody>,
|
||||||
|
Response = http::Response<
|
||||||
|
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
<T as tonic::codegen::Service<http::Request<tonic::body::BoxBody>>>::Error:
|
||||||
|
Into<StdError> + Send + Sync,
|
||||||
|
{
|
||||||
|
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
|
||||||
|
}
|
||||||
|
/// Compress requests with the given encoding.
|
||||||
|
///
|
||||||
|
/// This requires the server to support it otherwise it might respond with an
|
||||||
|
/// error.
|
||||||
|
#[must_use]
|
||||||
|
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||||
|
self.inner = self.inner.send_compressed(encoding);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Enable decompressing responses.
|
||||||
|
#[must_use]
|
||||||
|
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||||
|
self.inner = self.inner.accept_compressed(encoding);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Limits the maximum size of a decoded message.
|
||||||
|
///
|
||||||
|
/// Default: `4MB`
|
||||||
|
#[must_use]
|
||||||
|
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
|
||||||
|
self.inner = self.inner.max_decoding_message_size(limit);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Limits the maximum size of an encoded message.
|
||||||
|
///
|
||||||
|
/// Default: `usize::MAX`
|
||||||
|
#[must_use]
|
||||||
|
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
|
||||||
|
self.inner = self.inner.max_encoding_message_size(limit);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// / Model Info
|
||||||
|
pub async fn info(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::InfoRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Info");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut()
|
||||||
|
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Info"));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Service discovery
|
||||||
|
pub async fn service_discovery(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::ServiceDiscoveryResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v3.TextGenerationService/ServiceDiscovery",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"ServiceDiscovery",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Empties batch cache
|
||||||
|
pub async fn clear_cache(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::ClearCacheResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v3.TextGenerationService/ClearCache",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"ClearCache",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Remove requests from a cached batch
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::FilterBatchResponse>, tonic::Status>
|
||||||
|
{
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path = http::uri::PathAndQuery::from_static(
|
||||||
|
"/generate.v3.TextGenerationService/FilterBatch",
|
||||||
|
);
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"FilterBatch",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Warmup the model and compute max cache size
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::WarmupRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Warmup");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"Warmup",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Prefill batch and decode first token
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::PrefillRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::PrefillResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Prefill");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"Prefill",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Decode token for a list of prefilled batches
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::DecodeRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Decode");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"Decode",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
/// / Health check
|
||||||
|
pub async fn health(
|
||||||
|
&mut self,
|
||||||
|
request: impl tonic::IntoRequest<super::HealthRequest>,
|
||||||
|
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
|
||||||
|
self.inner.ready().await.map_err(|e| {
|
||||||
|
tonic::Status::new(
|
||||||
|
tonic::Code::Unknown,
|
||||||
|
format!("Service was not ready: {}", e.into()),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let codec = tonic::codec::ProstCodec::default();
|
||||||
|
let path =
|
||||||
|
http::uri::PathAndQuery::from_static("/generate.v3.TextGenerationService/Health");
|
||||||
|
let mut req = request.into_request();
|
||||||
|
req.extensions_mut().insert(GrpcMethod::new(
|
||||||
|
"generate.v3.TextGenerationService",
|
||||||
|
"Health",
|
||||||
|
));
|
||||||
|
self.inner.unary(req, path, codec).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
// This file is @generated by prost-build.
|
||||||
|
pub mod generate {
|
||||||
|
pub mod v3 {
|
||||||
|
include!("generate.v3.rs");
|
||||||
|
}
|
||||||
|
}
|
|
@ -101,11 +101,11 @@ impl ShardedClient {
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
|
@ -122,8 +122,10 @@ impl ShardedClient {
|
||||||
let results = join_all(futures)
|
let results = join_all(futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
|
||||||
Ok(results.into_iter().flatten().min())
|
let first = results.first().expect("Expect at least 1 warmup result");
|
||||||
|
assert!(results.iter().all(|&item| item == *first));
|
||||||
|
Ok(*first)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -108,20 +108,22 @@ impl Client {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_tokens: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
// Create requests
|
// Create requests
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let mut truncate = max_prefill_tokens - n_tokens;
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
truncate = min(max_input_tokens, truncate);
|
||||||
|
}
|
||||||
|
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
input_chunks
|
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
|
||||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
input_chunks.push(
|
input_chunks.push(
|
||||||
Chunk::Image(Image {
|
Chunk::Image(Image {
|
||||||
|
@ -137,7 +139,7 @@ impl Client {
|
||||||
// been updated to support chunks.
|
// been updated to support chunks.
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
// 1 request is enough to test vision heads.
|
// 1 request is enough to test vision heads.
|
||||||
// Sending images on other queries messes up easily with truncation.
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
@ -146,6 +148,12 @@ impl Client {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
max_total_tokens - truncate
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
inputs,
|
inputs,
|
||||||
|
@ -175,7 +183,7 @@ impl Client {
|
||||||
grammar_type: GrammarType::None as i32,
|
grammar_type: GrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: max_total_tokens - truncate,
|
max_new_tokens,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
ignore_eos_token: true,
|
ignore_eos_token: true,
|
||||||
}),
|
}),
|
||||||
|
@ -183,7 +191,7 @@ impl Client {
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += truncate;
|
||||||
|
|
||||||
// Check max_batch_size
|
// Check max_batch_size
|
||||||
if Some(requests.len()) == max_batch_size {
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
@ -195,19 +203,23 @@ impl Client {
|
||||||
id: 0,
|
id: 0,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_input_length,
|
max_tokens: max_input_tokens.unwrap_or(0),
|
||||||
max_blocks: 0,
|
max_blocks: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = tonic::Request::new(WarmupRequest {
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
batch: Some(batch),
|
batch: Some(batch),
|
||||||
max_input_length,
|
max_input_tokens,
|
||||||
max_prefill_tokens,
|
max_prefill_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
})
|
})
|
||||||
.inject_context();
|
.inject_context();
|
||||||
let response = self.stub.warmup(request).await?.into_inner();
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
Ok(response.max_supported_total_tokens)
|
Ok((
|
||||||
|
response.max_supported_total_tokens,
|
||||||
|
response.max_input_tokens,
|
||||||
|
response.max_total_tokens,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -102,11 +102,11 @@ impl ShardedClient {
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
|
@ -123,8 +123,11 @@ impl ShardedClient {
|
||||||
let results = join_all(futures)
|
let results = join_all(futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
|
||||||
Ok(results.into_iter().flatten().min())
|
|
||||||
|
let first = results.first().expect("Expect at least 1 warmup result");
|
||||||
|
assert!(results.iter().all(|&item| item == *first));
|
||||||
|
Ok(*first)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -37,12 +37,17 @@ pub struct BackendInfo {
|
||||||
pub attention_impl: String,
|
pub attention_impl: String,
|
||||||
#[schema(example = "1")]
|
#[schema(example = "1")]
|
||||||
pub block_size: u32,
|
pub block_size: u32,
|
||||||
|
|
||||||
|
#[schema(example = "30000")]
|
||||||
|
pub max_input_tokens: usize,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_total_tokens: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn connect_backend(
|
pub async fn connect_backend(
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
|
@ -51,14 +56,32 @@ pub async fn connect_backend(
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||||
// Helper function
|
// Helper function
|
||||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
let check_max_batch_total_tokens = |(
|
||||||
|
max_supported_batch_total_tokens,
|
||||||
|
shard_max_input_tokens,
|
||||||
|
shard_max_total_tokens,
|
||||||
|
): (Option<u32>, u32, u32)|
|
||||||
|
-> Result<(u32, usize, usize), V3Error> {
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
|
||||||
|
}
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
|
||||||
|
}
|
||||||
match max_supported_batch_total_tokens {
|
match max_supported_batch_total_tokens {
|
||||||
// Older models do not support automatic max-batch-total-tokens
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
None => {
|
None => {
|
||||||
let max_batch_total_tokens = max_batch_total_tokens
|
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
16000
|
||||||
|
.max(shard_max_total_tokens)
|
||||||
|
.max(max_batch_prefill_tokens),
|
||||||
|
);
|
||||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
Ok(max_batch_total_tokens)
|
Ok((
|
||||||
|
max_batch_total_tokens,
|
||||||
|
shard_max_input_tokens as usize,
|
||||||
|
shard_max_total_tokens as usize,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
// Flash attention models return their max supported total tokens
|
// Flash attention models return their max supported total tokens
|
||||||
Some(max_supported_batch_total_tokens) => {
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
@ -72,11 +95,15 @@ pub async fn connect_backend(
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
if shard_max_total_tokens > max_supported_batch_total_tokens {
|
||||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(max_supported_batch_total_tokens)
|
Ok((
|
||||||
|
max_supported_batch_total_tokens,
|
||||||
|
shard_max_input_tokens as usize,
|
||||||
|
shard_max_total_tokens as usize,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -96,23 +123,25 @@ pub async fn connect_backend(
|
||||||
|
|
||||||
// Warmup model
|
// Warmup model
|
||||||
tracing::info!("Warming up model");
|
tracing::info!("Warming up model");
|
||||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
let answer = sharded_client
|
||||||
sharded_client
|
.warmup(
|
||||||
.warmup(
|
max_input_tokens.map(|p| p as u32),
|
||||||
max_input_tokens as u32,
|
max_batch_prefill_tokens,
|
||||||
max_batch_prefill_tokens,
|
max_total_tokens.map(|p| p as u32),
|
||||||
max_total_tokens as u32,
|
max_batch_size,
|
||||||
max_batch_size,
|
)
|
||||||
)
|
.await
|
||||||
.await
|
.map_err(V3Error::Warmup)?;
|
||||||
.map_err(V3Error::Warmup)?,
|
let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
|
||||||
)?;
|
check_max_batch_total_tokens(answer)?;
|
||||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
||||||
|
|
||||||
let backend_info = BackendInfo {
|
let backend_info = BackendInfo {
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
model_device_type: shard_info.device_type.clone(),
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
|
|
@ -18,10 +18,10 @@ struct Args {
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
#[clap(default_value = "5", long, env)]
|
#[clap(default_value = "5", long, env)]
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
#[clap(default_value = "1024", long, env)]
|
#[clap(long, env)]
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
#[clap(default_value = "2048", long, env)]
|
#[clap(long, env)]
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
#[clap(default_value = "1.2", long, env)]
|
#[clap(default_value = "1.2", long, env)]
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
#[clap(default_value = "4096", long, env)]
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> {
|
||||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_tokens >= max_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> {
|
||||||
// Validate remaining args now that the backend is known
|
// Validate remaining args now that the backend is known
|
||||||
let support_chunking = backend_info.support_chunking;
|
let support_chunking = backend_info.support_chunking;
|
||||||
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
||||||
|
|
||||||
|
if max_input_tokens.is_none() {
|
||||||
|
tracing::info!(
|
||||||
|
"Maximum input tokens defaulted to {}",
|
||||||
|
backend_info.max_input_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens.is_none() {
|
||||||
|
tracing::info!(
|
||||||
|
"Maximum total tokens defaulted to {}",
|
||||||
|
backend_info.max_total_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_input_tokens = backend_info.max_input_tokens;
|
||||||
|
let max_total_tokens = backend_info.max_total_tokens;
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,10 +137,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct RawConfig {
|
struct RawConfig {
|
||||||
max_position_embeddings: Option<usize>,
|
|
||||||
n_positions: Option<usize>,
|
|
||||||
model_type: Option<String>,
|
model_type: Option<String>,
|
||||||
max_seq_len: Option<usize>,
|
|
||||||
quantization_config: Option<QuantizationConfig>,
|
quantization_config: Option<QuantizationConfig>,
|
||||||
n_embd: Option<usize>,
|
n_embd: Option<usize>,
|
||||||
hidden_size: Option<usize>,
|
hidden_size: Option<usize>,
|
||||||
|
@ -160,7 +157,6 @@ struct VisionConfig {}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
max_position_embeddings: Option<usize>,
|
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
head_dim: Option<usize>,
|
head_dim: Option<usize>,
|
||||||
model_type: Option<String>,
|
model_type: Option<String>,
|
||||||
|
@ -170,10 +166,6 @@ struct Config {
|
||||||
|
|
||||||
impl From<RawConfig> for Config {
|
impl From<RawConfig> for Config {
|
||||||
fn from(other: RawConfig) -> Self {
|
fn from(other: RawConfig) -> Self {
|
||||||
let max_position_embeddings = other
|
|
||||||
.max_position_embeddings
|
|
||||||
.or(other.max_seq_len)
|
|
||||||
.or(other.n_positions);
|
|
||||||
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
||||||
let head_dim = other.head_dim.or_else(|| {
|
let head_dim = other.head_dim.or_else(|| {
|
||||||
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
|
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
|
||||||
|
@ -195,7 +187,6 @@ impl From<RawConfig> for Config {
|
||||||
let vision_config = other.vision_config;
|
let vision_config = other.vision_config;
|
||||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||||
Config {
|
Config {
|
||||||
max_position_embeddings,
|
|
||||||
quantize,
|
quantize,
|
||||||
head_dim,
|
head_dim,
|
||||||
model_type,
|
model_type,
|
||||||
|
@ -472,7 +463,7 @@ struct Args {
|
||||||
/// for users. The larger this value, the longer prompt users can send which
|
/// for users. The larger this value, the longer prompt users can send which
|
||||||
/// can impact the overall memory required to handle the load.
|
/// can impact the overall memory required to handle the load.
|
||||||
/// Please note that some models have a finite range of sequence they can handle.
|
/// Please note that some models have a finite range of sequence they can handle.
|
||||||
/// Default to min(max_position_embeddings - 1, 4095)
|
/// Default to min(max_allocatable, max_position_embeddings) - 1
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
max_input_tokens: Option<usize>,
|
max_input_tokens: Option<usize>,
|
||||||
|
|
||||||
|
@ -488,7 +479,7 @@ struct Args {
|
||||||
/// `1511` max_new_tokens.
|
/// `1511` max_new_tokens.
|
||||||
/// The larger this value, the larger amount each request will be in your RAM
|
/// The larger this value, the larger amount each request will be in your RAM
|
||||||
/// and the less effective batching can be.
|
/// and the less effective batching can be.
|
||||||
/// Default to min(max_position_embeddings, 4096)
|
/// Default to min(max_allocatable, max_position_embeddings)
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
max_total_tokens: Option<usize>,
|
max_total_tokens: Option<usize>,
|
||||||
|
|
||||||
|
@ -718,9 +709,9 @@ fn shard_manager(
|
||||||
cuda_memory_fraction: f32,
|
cuda_memory_fraction: f32,
|
||||||
rope_scaling: Option<RopeScaling>,
|
rope_scaling: Option<RopeScaling>,
|
||||||
rope_factor: Option<f32>,
|
rope_factor: Option<f32>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
lora_adapters: Option<String>,
|
lora_adapters: Option<String>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
otlp_service_name: String,
|
otlp_service_name: String,
|
||||||
|
@ -805,8 +796,10 @@ fn shard_manager(
|
||||||
shard_args.push(otlp_service_name);
|
shard_args.push(otlp_service_name);
|
||||||
|
|
||||||
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
|
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
|
||||||
shard_args.push("--max-input-tokens".to_string());
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
shard_args.push(max_input_tokens.to_string());
|
shard_args.push("--max-input-tokens".to_string());
|
||||||
|
shard_args.push(max_input_tokens.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Copy current process env
|
// Copy current process env
|
||||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||||
|
@ -854,10 +847,12 @@ fn shard_manager(
|
||||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
envs.push((
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
"MAX_TOTAL_TOKENS".into(),
|
envs.push((
|
||||||
max_total_tokens.to_string().into(),
|
"MAX_TOTAL_TOKENS".into(),
|
||||||
));
|
max_total_tokens.to_string().into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
if let Some(max_batch_size) = max_batch_size {
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||||
}
|
}
|
||||||
|
@ -1313,8 +1308,8 @@ fn spawn_shards(
|
||||||
num_shard: usize,
|
num_shard: usize,
|
||||||
args: &Args,
|
args: &Args,
|
||||||
cuda_graphs: Vec<usize>,
|
cuda_graphs: Vec<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
max_log_level: LevelFilter,
|
max_log_level: LevelFilter,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
|
@ -1432,8 +1427,8 @@ fn compute_type(num_shard: usize) -> Option<String> {
|
||||||
fn spawn_webserver(
|
fn spawn_webserver(
|
||||||
num_shard: usize,
|
num_shard: usize,
|
||||||
args: Args,
|
args: Args,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
shutdown_receiver: &mpsc::Receiver<()>,
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
|
@ -1452,10 +1447,6 @@ fn spawn_webserver(
|
||||||
args.max_stop_sequences.to_string(),
|
args.max_stop_sequences.to_string(),
|
||||||
"--max-top-n-tokens".to_string(),
|
"--max-top-n-tokens".to_string(),
|
||||||
args.max_top_n_tokens.to_string(),
|
args.max_top_n_tokens.to_string(),
|
||||||
"--max-input-tokens".to_string(),
|
|
||||||
max_input_tokens.to_string(),
|
|
||||||
"--max-total-tokens".to_string(),
|
|
||||||
max_total_tokens.to_string(),
|
|
||||||
"--max-batch-prefill-tokens".to_string(),
|
"--max-batch-prefill-tokens".to_string(),
|
||||||
max_batch_prefill_tokens.to_string(),
|
max_batch_prefill_tokens.to_string(),
|
||||||
"--waiting-served-ratio".to_string(),
|
"--waiting-served-ratio".to_string(),
|
||||||
|
@ -1473,6 +1464,18 @@ fn spawn_webserver(
|
||||||
"--tokenizer-name".to_string(),
|
"--tokenizer-name".to_string(),
|
||||||
args.model_id,
|
args.model_id,
|
||||||
];
|
];
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
router_args.extend_from_slice(&[
|
||||||
|
"--max-input-tokens".to_string(),
|
||||||
|
max_input_tokens.to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
router_args.extend_from_slice(&[
|
||||||
|
"--max-total-tokens".to_string(),
|
||||||
|
max_total_tokens.to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
// Pass usage stats flags to router
|
// Pass usage stats flags to router
|
||||||
router_args.push("--usage-stats".to_string());
|
router_args.push("--usage-stats".to_string());
|
||||||
|
@ -1664,28 +1667,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|
||||||
let quantize = config.as_ref().and_then(|c| c.quantize);
|
let quantize = config.as_ref().and_then(|c| c.quantize);
|
||||||
// Quantization usually means you're even more RAM constrained.
|
// Quantization usually means you're even more RAM constrained.
|
||||||
let max_default = 4096;
|
|
||||||
|
|
||||||
let max_position_embeddings = if let Some(config) = &config {
|
|
||||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
|
||||||
if max_position_embeddings > max_default {
|
|
||||||
let max = max_position_embeddings;
|
|
||||||
if args.max_input_tokens.is_none()
|
|
||||||
&& args.max_total_tokens.is_none()
|
|
||||||
&& args.max_batch_prefill_tokens.is_none()
|
|
||||||
{
|
|
||||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
|
||||||
}
|
|
||||||
max_default
|
|
||||||
} else {
|
|
||||||
max_position_embeddings
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
max_default
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
max_default
|
|
||||||
};
|
|
||||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||||
std::env::set_var("PREFIX_CACHING", prefix_caching);
|
std::env::set_var("PREFIX_CACHING", prefix_caching);
|
||||||
|
@ -1698,35 +1679,26 @@ fn main() -> Result<(), LauncherError> {
|
||||||
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
|
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
|
||||||
(None, None) => {
|
Some(max_input_tokens)
|
||||||
let value = max_position_embeddings - 1;
|
|
||||||
tracing::info!("Default `max_input_tokens` to {value}");
|
|
||||||
value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let max_total_tokens = {
|
|
||||||
match args.max_total_tokens {
|
|
||||||
Some(max_total_tokens) => max_total_tokens,
|
|
||||||
None => {
|
|
||||||
let value = max_position_embeddings;
|
|
||||||
tracing::info!("Default `max_total_tokens` to {value}");
|
|
||||||
value
|
|
||||||
}
|
}
|
||||||
|
(None, None) => None,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let max_total_tokens = args.max_total_tokens;
|
||||||
let max_batch_prefill_tokens = {
|
let max_batch_prefill_tokens = {
|
||||||
match args.max_batch_prefill_tokens {
|
match args.max_batch_prefill_tokens {
|
||||||
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
||||||
None => {
|
None => {
|
||||||
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
|
// let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
|
||||||
max_batch_size * max_input_tokens
|
// max_batch_size * max_input_tokens
|
||||||
} else {
|
// } else {
|
||||||
// Adding some edge in order to account for potential block_size alignment
|
// // Adding some edge in order to account for potential block_size alignment
|
||||||
// issue.
|
// // issue.
|
||||||
max_input_tokens + 50
|
// max_input_tokens + 50
|
||||||
} as u32;
|
// } as u32;
|
||||||
|
// TODO figure out hardware optimal value
|
||||||
|
let value = 4096;
|
||||||
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
||||||
value
|
value
|
||||||
}
|
}
|
||||||
|
@ -1734,10 +1706,12 @@ fn main() -> Result<(), LauncherError> {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_tokens >= max_total_tokens {
|
if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
|
||||||
return Err(LauncherError::ArgumentValidation(
|
if max_input_tokens >= max_total_tokens {
|
||||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
return Err(LauncherError::ArgumentValidation(
|
||||||
));
|
format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||||
|
@ -1792,11 +1766,13 @@ fn main() -> Result<(), LauncherError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
max_total_tokens, max_batch_total_tokens
|
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
)));
|
max_total_tokens, max_batch_total_tokens
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -272,12 +272,18 @@ message DecodeResponse {
|
||||||
message WarmupRequest {
|
message WarmupRequest {
|
||||||
/// Batch to warmup on
|
/// Batch to warmup on
|
||||||
Batch batch = 1;
|
Batch batch = 1;
|
||||||
uint32 max_input_length = 2;
|
optional uint32 max_input_tokens = 2;
|
||||||
uint32 max_prefill_tokens = 3;
|
uint32 max_prefill_tokens = 3;
|
||||||
uint32 max_total_tokens = 4;
|
optional uint32 max_total_tokens = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message WarmupResponse {
|
message WarmupResponse {
|
||||||
/// Maximum number of tokens supported by the model
|
/// Maximum number of tokens supported by the model
|
||||||
optional uint32 max_supported_total_tokens = 1;
|
optional uint32 max_supported_total_tokens = 1;
|
||||||
|
/// Maximum input tokens allowed for clients. Equals the request value if it was set.
|
||||||
|
/// Otherwise warmup automatically allocates a value here
|
||||||
|
uint32 max_input_tokens = 2;
|
||||||
|
/// Maximum total tokens allowed for clients. Equals the request value if it was set.
|
||||||
|
/// Otherwise warmup automatically allocates a value here
|
||||||
|
uint32 max_total_tokens = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,6 +78,10 @@ tracer = trace.get_tracer(__name__)
|
||||||
SLIDING_WINDOW: Optional[int] = None
|
SLIDING_WINDOW: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def small_power_of_2(n: int):
    """Return the largest power of two strictly below ``n``.

    Examples: ``5 -> 4``, ``4 -> 2``, ``3 -> 2``, ``2 -> 1``.

    ``n == 1`` has no power of two below it; the exponent is clamped to zero
    so the call returns ``1`` instead of raising ``ValueError`` from a
    negative shift count. Results for every other input are unchanged.

    Args:
        n: Upper bound (exclusive); expected to be a positive integer.

    Returns:
        A power of two as an ``int``.
    """
    # (n - 1).bit_length() - 1 is the index of the highest set bit of n - 1.
    # It evaluates to -1 when n == 1, which `<<` rejects, hence the clamp.
    return 1 << max((n - 1).bit_length() - 1, 0)
|
||||||
|
|
||||||
|
|
||||||
def set_sliding_window(sliding_window: int):
|
def set_sliding_window(sliding_window: int):
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
SLIDING_WINDOW = sliding_window
|
SLIDING_WINDOW = sliding_window
|
||||||
|
@ -1377,11 +1381,40 @@ class FlashCausalLM(Model):
|
||||||
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(
|
||||||
|
self,
|
||||||
|
batch: FlashCausalLMBatch,
|
||||||
|
max_input_tokens: Optional[int],
|
||||||
|
max_total_tokens: Optional[int],
|
||||||
|
):
|
||||||
# The warmup batch is the biggest batch we could ever receive
|
# The warmup batch is the biggest batch we could ever receive
|
||||||
self.kv_cache = []
|
self.kv_cache = []
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
|
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||||
|
# Calculate the number of blocks that can be allocated with the free memory
|
||||||
|
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||||
|
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||||
|
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||||
|
|
||||||
|
if max_total_tokens is None:
|
||||||
|
model_max_length = self.tokenizer.model_max_length
|
||||||
|
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||||
|
spare_blocks = (
|
||||||
|
# Leave 5% for some wiggle room
|
||||||
|
int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
|
||||||
|
+ batch.num_blocks
|
||||||
|
)
|
||||||
|
spare_blocks = small_power_of_2(spare_blocks)
|
||||||
|
|
||||||
|
available_blocks = min(model_max_length, spare_blocks)
|
||||||
|
batch.num_blocks = available_blocks
|
||||||
|
batch.max_blocks = available_blocks
|
||||||
|
max_input_tokens = (
|
||||||
|
available_blocks - 1 if max_input_tokens is None else max_input_tokens
|
||||||
|
)
|
||||||
|
max_total_tokens = available_blocks
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.init_kv_cache(
|
self.init_kv_cache(
|
||||||
batch.num_blocks,
|
batch.num_blocks,
|
||||||
|
@ -1393,6 +1426,7 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
max_bt = batch.max_blocks
|
max_bt = batch.max_blocks
|
||||||
max_s = max_bt * BLOCK_SIZE
|
max_s = max_bt * BLOCK_SIZE
|
||||||
|
batch_num_blocks = batch.num_blocks
|
||||||
|
|
||||||
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||||
torch.cuda.tunable.tuning_enable(False)
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
|
@ -1405,14 +1439,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
synchronize(self.device)
|
synchronize(self.device)
|
||||||
|
|
||||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
|
||||||
# Calculate the number of blocks that can be allocated with the free memory
|
|
||||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
|
||||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
|
||||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
|
||||||
|
|
||||||
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||||
batch_num_blocks = batch.num_blocks if batch is not None else 0
|
|
||||||
|
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
# Leave 5% for some wiggle room
|
# Leave 5% for some wiggle room
|
||||||
|
@ -1505,7 +1532,9 @@ class FlashCausalLM(Model):
|
||||||
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
||||||
)
|
)
|
||||||
|
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
assert max_input_tokens is not None
|
||||||
|
assert max_total_tokens is not None
|
||||||
|
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||||
|
|
||||||
def tunableop_warmup(self, seqlen: int):
|
def tunableop_warmup(self, seqlen: int):
|
||||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
|
|
@ -128,9 +128,11 @@ class Model(ABC):
|
||||||
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
|
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def warmup(self, batch: B) -> Optional[int]:
|
def warmup(
|
||||||
|
self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
|
||||||
|
) -> Tuple[Optional[int], int, int]:
|
||||||
self.generate_token(batch)
|
self.generate_token(batch)
|
||||||
return None
|
return None, 0, 0
|
||||||
|
|
||||||
def decode_token(
|
def decode_token(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
batch = self.model.batch_type.from_pb(
|
batch = self.model.batch_type.from_pb(
|
||||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||||
)
|
)
|
||||||
max_supported_total_tokens = self.model.warmup(batch)
|
|
||||||
|
# Override default values with None for clearer semantics.
|
||||||
|
max_input_tokens = (
|
||||||
|
request.max_input_tokens if request.HasField("max_input_tokens") else None
|
||||||
|
)
|
||||||
|
max_total_tokens = (
|
||||||
|
request.max_total_tokens if request.HasField("max_total_tokens") else None
|
||||||
|
)
|
||||||
|
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
|
||||||
|
self.model.warmup(batch, max_input_tokens, max_total_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
return generate_pb2.WarmupResponse(
|
return generate_pb2.WarmupResponse(
|
||||||
max_supported_total_tokens=max_supported_total_tokens
|
max_supported_total_tokens=max_supported_total_tokens,
|
||||||
|
max_input_tokens=max_input_tokens,
|
||||||
|
max_total_tokens=max_total_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def Prefill(self, request, context):
|
async def Prefill(self, request, context):
|
||||||
|
|
Loading…
Reference in New Issue