diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7e048b7..c44dd57 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -453,7 +453,10 @@ class FlashCausalLM(Model): ) # Set in batch in case it needs to be used later in concatenate() batch.past_pad = self.past_pad - if len(batch) != 1: + if len(batch) == 1: + # present is already pre-padded + batch.past_key_values = present + else: # Add padding after each sequence # This will have the correct shape after the final past_key_values concatenation before the model # forward