diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 820d6696..15f424f0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -33,7 +33,7 @@ fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Op
     if let (Some(config), Some(compute)) = (config, compute) {
         if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
             tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute / 2) as usize;
+            let optimal_size = (f16_max_compute / model_compute) as usize;
             if optimal_size > 100 {
                 // Ignore calculations that's too low
                 // Most likely an error
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 389736ce..a890f824 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -57,6 +57,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
@@ -292,6 +293,10 @@ class FlashCausalLMBatch(Batch):
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
+            ### XXX: This consumes so much memory on long requests
+            ### Deactivating it by default seems like the best course.
+            if not REQUEST_LOGPROBS:
+                r.prefill_logprobs = False
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 4ac6a6b4..dfae8ed2 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,6 +5,7 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
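
For reviewers, a minimal sketch of how the new opt-in flag behaves, assuming the environment is read once at startup as in the `globals.py` hunk above; the `FakeRequest` class and `gate_prefill_logprobs` helper below are illustrative only and not part of the patch:

```python
import os

# Mirrors the globals.py change: prefill logprobs stay off unless the operator
# explicitly opts in with REQUEST_LOGPROBS=1 (or "true").
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}


class FakeRequest:
    """Illustrative stand-in for a request carrying a prefill_logprobs flag."""

    def __init__(self, prefill_logprobs: bool):
        self.prefill_logprobs = prefill_logprobs


def gate_prefill_logprobs(requests):
    # Same gating as the flash_causal_lm.py hunk: when the global flag is off,
    # any per-request ask for prefill logprobs is forced to False, avoiding the
    # memory blow-up on long prompts noted in the XXX comment.
    for r in requests:
        if not REQUEST_LOGPROBS:
            r.prefill_logprobs = False
    return requests


if __name__ == "__main__":
    reqs = gate_prefill_logprobs([FakeRequest(True), FakeRequest(False)])
    print([r.prefill_logprobs for r in reqs])
    # With REQUEST_LOGPROBS unset: [False, False]
    # With REQUEST_LOGPROBS=1:     [True, False]
```

Deployments that still need prefill logprobs can presumably restore the previous behavior by exporting the variable before launch, e.g. something like `REQUEST_LOGPROBS=1 text-generation-launcher ...`, since the Python server inherits its environment from the launcher.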