diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 4a072e31..6e054185 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -33,7 +33,7 @@ fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Op
     if let (Some(config), Some(compute)) = (config, compute) {
         if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
             tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute) as usize;
+            let optimal_size = (f16_max_compute / model_compute / 2) as usize;
             if optimal_size > 100 {
                 // Ignore calculations that's too low
                 // Most likely an error
@@ -1484,6 +1484,8 @@ impl ComputeType {
         let card_flop = match &self.card[..] {
             // https://www.nvidia.com/en-us/data-center/l4/
             "nvidia-l4" => Some(121 * 10u64.pow(12)),
+            // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
+            "nvidia-a10g" => Some(125 * 10u64.pow(12)),
             card => {
                 tracing::warn!("Unkown compute for card {card}");
                 None
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 6e905b4a..389736ce 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1534,12 +1534,16 @@ class FlashCausalLM(Model):
             )
             batch_num_blocks = batch.num_blocks
+            num_tokens = batch.to_pb().current_tokens
+            logger.info(f"BLOCKS {batch.num_blocks}")
+            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+            logger.info(f"Free memory {free_memory}")
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
             _, _batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
+                f"Not enough memory to handle {num_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
 
@@ -2079,6 +2083,10 @@ class FlashCausalLM(Model):
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
+            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+            logger.info(f"Free memory {free_memory / 1e9}GB")
+            logmemory = out.nelement() * out.element_size()
+            logger.info(f"Log memory {logmemory / 1e9}GB")
             torch.log_softmax(out, -1, out=out)
             prefill_logprobs_tensor = out
             prefill_logprobs = torch.gather(