diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f332ab5..9cd9ed8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
+            # True index for past
+            past_key_values.append(self.past_key_values[2 * idx])
+
             if not single_request:
-                # True index for past
-                past_key_values.append(self.past_key_values[2 * idx])
                 # Add one padding
                 past_key_values.append(self.past_pad)

@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
         if single_request:
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
-                self.past_key_values[0],
+                past_key_values[0],
                 (
                     0,
                     0,
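
For context (not part of the patch): a minimal sketch of the indexing this fix relies on. It assumes, as the patch suggests, that the batch's flat past_key_values list interleaves each request's past tensor with a shared padding slot (self.past_pad is appended right after each past), so request idx's past lives at 2 * idx. It also illustrates the second hunk: once the filter loop has rebuilt a local past_key_values list, the bs = 1 preallocation must read past_key_values[0] (the kept request's past), not self.past_key_values[0] (the first slot of the unfiltered batch). The string values below are placeholders for the real torch tensors.

# Sketch only: strings stand in for per-request past tensors and the
# shared padding tensor; the real lists hold torch tensors.
past_pad = "PAD"

# Assumed interleaved layout for a 3-request batch:
#   [past_0, PAD, past_1, PAD, past_2, PAD]
stored_past_key_values = ["past_0", past_pad, "past_1", past_pad, "past_2", past_pad]

# "True index for past": request idx's past sits at 2 * idx.
for idx in range(3):
    assert stored_past_key_values[2 * idx] == f"past_{idx}"

# Filtering down to request 2 only (the single_request case): the rebuilt
# local list starts at the kept request's past, so index 0 is correct
# there, while stored_past_key_values[0] would still be request 0's past.
kept_idx = 2
local_past_key_values = [stored_past_key_values[2 * kept_idx]]
assert local_past_key_values[0] == "past_2"
assert stored_past_key_values[0] != local_past_key_values[0]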