fix(server): fix flash causal (#218)
This commit is contained in:
parent
afc5b999d0
commit
86bca365df
|
@ -453,7 +453,10 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
# Set in batch in case it needs to be used later in concatenate()
|
# Set in batch in case it needs to be used later in concatenate()
|
||||||
batch.past_pad = self.past_pad
|
batch.past_pad = self.past_pad
|
||||||
if len(batch) != 1:
|
if len(batch) == 1:
|
||||||
|
# present is already pre-padded
|
||||||
|
batch.past_key_values = present
|
||||||
|
else:
|
||||||
# Add padding after each sequence
|
# Add padding after each sequence
|
||||||
# This will have the correct shape after the final past_key_values concatenation before the model
|
# This will have the correct shape after the final past_key_values concatenation before the model
|
||||||
# forward
|
# forward
|
||||||
|
|
Loading…
Reference in New Issue