fix(server): fix flash causal (#218)

This commit is contained in:
OlivierDehaene 2023-04-21 19:42:16 +02:00 committed by GitHub
parent afc5b999d0
commit 86bca365df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 1 deletions

View File

@ -453,7 +453,10 @@ class FlashCausalLM(Model):
) )
# Set in batch in case it needs to be used later in concatenate() # Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad batch.past_pad = self.past_pad
if len(batch) != 1: if len(batch) == 1:
# present is already pre-padded
batch.past_key_values = present
else:
# Add padding after each sequence # Add padding after each sequence
# This will have the correct shape after the final past_key_values concatenation before the model # This will have the correct shape after the final past_key_values concatenation before the model
# forward # forward