diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a79467a5..1c2e5e01 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -137,7 +137,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 
 #[derive(Deserialize)]
 struct RawConfig {
+    max_position_embeddings: Option<usize>,
+    n_positions: Option<usize>,
     model_type: Option<String>,
+    max_seq_len: Option<usize>,
     quantization_config: Option<QuantizationConfig>,
     n_embd: Option<usize>,
     hidden_size: Option<usize>,
@@ -157,6 +160,7 @@ struct VisionConfig {}
 
 #[derive(Deserialize)]
 struct Config {
+    max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
@@ -166,6 +170,10 @@ struct Config {
 
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
+        let max_position_embeddings = other
+            .max_position_embeddings
+            .or(other.max_seq_len)
+            .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
         let head_dim = other.head_dim.or_else(|| {
             match (other.hidden_size, other.n_embd, other.num_attention_heads) {
@@ -187,6 +195,7 @@ impl From<RawConfig> for Config {
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
+            max_position_embeddings,
             quantize,
             head_dim,
             model_type,
@@ -479,7 +488,7 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    /// Default to min(max_allocatable, max_position_embeddings)
+    /// Default to min(max_position_embeddings, 4096)
     #[clap(long, env)]
     max_total_tokens: Option<usize>,
 
@@ -1667,6 +1676,28 @@ fn main() -> Result<(), LauncherError> {
     let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
     // Quantization usually means you're even more RAM constrained.
+    let max_default = 4096;
+
+    let max_position_embeddings = if let Some(config) = &config {
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                max_default
+            } else {
+                max_position_embeddings
+            }
+        } else {
+            max_default
+        }
+    } else {
+        max_default
+    };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
@@ -1690,15 +1721,8 @@ fn main() -> Result<(), LauncherError> {
         match args.max_batch_prefill_tokens {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                // let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
-                //     max_batch_size * max_input_tokens
-                // } else {
-                //     // Adding some edge in order to account for potential block_size alignement
-                //     // issue.
-                //     max_input_tokens + 50
-                // } as u32;
                 // TODO figure out hardware optimal value
-                let value = 4096;
+                let value = 4096.min(max_position_embeddings as u32);
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                 value
             }
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bbff0243..b0de59bf 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1412,7 +1412,7 @@ class FlashCausalLM(Model):
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
 
-            _, batch, _ = self.generate_token(batch)
+            _, _batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@@ -1442,14 +1442,14 @@ class FlashCausalLM(Model):
 
             max_total_tokens = num_blocks * BLOCK_SIZE
         else:
-            max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
+            max_total_tokens = sum(batch.cache_lengths)
 
         max_input_tokens = (
             max_total_tokens - 1
             if max_input_tokens is None
             else max_input_tokens
         )
 
-        del batch
+        del _batch, batch
         self.kv_cache = []
         empty_cache()
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 5790de41..c75592c1 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -132,7 +132,13 @@ class Model(ABC):
         self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
     ) -> Tuple[Optional[int], int, int]:
         self.generate_token(batch)
-        return None, 0, 0
+        total = sum(len(i) for i in batch.input_ids)
+        if max_total_tokens is None:
+            max_total_tokens = total
+
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def decode_token(
         self,