fix(server): fix flash causal (#218)
This commit is contained in:
parent
afc5b999d0
commit
86bca365df
|
@@ -453,7 +453,10 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
# Set in batch in case it needs to be used later in concatenate()
|
||||
batch.past_pad = self.past_pad
|
||||
- if len(batch) != 1:
|
||||
+ if len(batch) == 1:
|
||||
+ # present is already pre-padded
|
||||
+ batch.past_key_values = present
|
||||
+ else:
|
||||
# Add padding after each sequence
|
||||
# This will have the correct shape after the final past_key_values concatenation before the model
|
||||
# forward
|
||||
|
|
Loading…
Reference in New Issue