From 86bca365df0e39dc4a02f3519aee64e9168c1430 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 21 Apr 2023 19:42:16 +0200
Subject: [PATCH] fix(server): fix flash causal (#218)

---
 server/text_generation_server/models/flash_causal_lm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7e048b7..c44dd57 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -453,7 +453,10 @@ class FlashCausalLM(Model):
         )
         # Set in batch in case it needs to be used later in concatenate()
         batch.past_pad = self.past_pad
-        if len(batch) != 1:
+        if len(batch) == 1:
+            # present is already pre-padded
+            batch.past_key_values = present
+        else:
             # Add padding after each sequence
             # This will have the correct shape after the final past_key_values concatenation before the model
             # forward
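
For reference, a minimal sketch of the control flow after this patch. The surrounding
method context (a model forward pass that produces `present`, and `self.past_pad` being a
pre-allocated padding tensor) is assumed here and is not shown in the diff itself:

    # Sketch of the patched branch, assuming it runs inside a FlashCausalLM method
    # after the model forward has produced `present`, the key/value cache for this step.
    batch.past_pad = self.past_pad
    if len(batch) == 1:
        # Single-sequence batch: `present` already includes the padding slot,
        # so it can be stored directly as the next step's past_key_values.
        batch.past_key_values = present
    else:
        # Multi-sequence batch: padding is added after each sequence so the tensor
        # has the correct shape when past_key_values are concatenated before the
        # next model forward (that padding logic lies outside this hunk).
        ...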