fix(server): cleanup new flash past_key_values logic (#217)
This commit is contained in:
parent
db4cb5e4ed
commit
afc5b999d0
|
@ -594,7 +594,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
# We added padding that now need to slice
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
|
|
|
@ -657,7 +657,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
# We added padding that now need to slice
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
|
|
|
@ -520,7 +520,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.h):
|
||||
# We added padding that now need to slice
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
|
|
|
@ -404,8 +404,7 @@ class FlashCausalLM(Model):
|
|||
# Shortcut when batch_size == 1
|
||||
if len(batch) == 1:
|
||||
input_ids = batch.input_ids[0].view(-1)
|
||||
# Slice to remove extra padding
|
||||
# past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
|
||||
# No need to slice as flash attention will take care of it with cu_seqlens
|
||||
past_key_values = batch.past_key_values
|
||||
else:
|
||||
# Concatenate tensors
|
||||
|
@ -454,13 +453,7 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
# Set in batch in case it needs to be used later in concatenate()
|
||||
batch.past_pad = self.past_pad
|
||||
if len(batch) == 1:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
batch.past_key_values = torch.nn.functional.pad(
|
||||
present,
|
||||
(0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
|
||||
)
|
||||
else:
|
||||
if len(batch) != 1:
|
||||
# Add padding after each sequence
|
||||
# This will have the correct shape after the final past_key_values concatenation before the model
|
||||
# forward
|
||||
|
|
Loading…
Reference in New Issue