From 5f32dea1e2e6f1b4a86773ef7f1e8861fca5c61d Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 17 Oct 2024 08:49:02 -0400
Subject: [PATCH] fix: prefer inplace softmax to avoid copy (#2661)

* fix: prefer inplace softmax to avoid copy

* Update server/text_generation_server/models/flash_causal_lm.py

Co-authored-by: Nicolas Patry

---------

Co-authored-by: Nicolas Patry
---
 server/text_generation_server/models/flash_causal_lm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c9b7decd..7018edb1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1922,8 +1922,9 @@ class FlashCausalLM(Model):
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
-            # Get prefill logprobs
-            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
+            torch.log_softmax(out, -1, out=out)
+            prefill_logprobs_tensor = out
             prefill_logprobs = torch.gather(
                 prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
             )
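
For reference, below is a minimal standalone sketch of the in-place pattern this patch applies. The tensor shapes and prefill_tokens_indices values are illustrative placeholders, not TGI's actual max_batch_prefill_tokens or vocab_size; only the torch.log_softmax(..., out=out) call mirrors the patched code.

import torch

# Illustrative shapes only; in TGI they come from the running batch and the model vocab.
num_prefill_tokens = 4096
vocab_size = 32000
out = torch.randn(num_prefill_tokens, vocab_size)

# Before the patch: torch.log_softmax(out, -1) allocates a second
# (num_prefill_tokens, vocab_size) tensor alongside `out`.
# After the patch: the result is written back into `out`, so no extra copy is made.
torch.log_softmax(out, -1, out=out)
prefill_logprobs_tensor = out

# Gather the logprob of each prefill token, as in flash_causal_lm.py
# (prefill_tokens_indices here is a made-up stand-in).
prefill_tokens_indices = torch.randint(0, vocab_size, (num_prefill_tokens,))
prefill_logprobs = torch.gather(
    prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
)

The in-place form is only safe when the raw logits in `out` are not needed again after this point, which is the premise of the change.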