From 4b460e72fb0df7ef451e1f5a48b4cb783d42b9c2 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 21 Apr 2023 20:26:01 +0200
Subject: [PATCH] fix(server): fix flash batch filtering (#220)

---
 server/text_generation_server/models/flash_causal_lm.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f332ab5..9cd9ed8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
+            # True index for past
+            past_key_values.append(self.past_key_values[2 * idx])
+
             if not single_request:
-                # True index for past
-                past_key_values.append(self.past_key_values[2 * idx])
                 # Add one padding
                 past_key_values.append(self.past_pad)
 
@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
         if single_request:
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
-                self.past_key_values[0],
+                past_key_values[0],
                 (
                     0,
                     0,
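
Reviewer note (not part of the upstream patch): in this batch layout, past_key_values interleaves one padding tensor after each request's past, i.e. [past_0, pad, past_1, pad, ...], so request idx lives at position 2 * idx. Before the fix, the append only ran when more than one request was kept, so the single-request path later padded the stale self.past_key_values[0] (the old batch's first request) instead of the kept request's past. Below is a minimal, self-contained sketch of that filtering logic with hypothetical toy tensors; the names filter_batch, kept_indices, and the tensor shapes are illustrative, not the library's API.

import torch

past_pad = torch.zeros(1, 4)  # shared padding tensor
pasts = [torch.full((1, 4), i) for i in range(3)]  # per-request pasts

# Interleaved layout: request idx lives at position 2 * idx
past_key_values = []
for p in pasts:
    past_key_values.extend([p, past_pad])

def filter_batch(kept_indices):
    filtered = []
    single_request = len(kept_indices) == 1
    for idx in kept_indices:
        # True index for past: 2 * idx, because of the interleaved padding.
        # Before the fix this append was guarded by `if not single_request:`,
        # so the bs = 1 path never collected the kept request's past.
        filtered.append(past_key_values[2 * idx])
        if not single_request:
            # Re-insert one padding per kept request
            filtered.append(past_pad)
    return filtered

# Keeping only request 2: the filtered past must be pasts[2], not pasts[0].
assert torch.equal(filter_batch([2])[0], pasts[2])

The second hunk follows from the first: once the kept request's past is appended unconditionally, the bs = 1 preallocation must pad the freshly filtered past_key_values[0] rather than the old batch's self.past_key_values[0].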