fix(server): cleanup new flash past_key_values logic (#217)

OlivierDehaene, 2023-04-21 16:19:04 +02:00 (committed by GitHub)
parent db4cb5e4ed
commit afc5b999d0
4 changed files with 5 additions and 12 deletions

@@ -594,7 +594,7 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None

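The same comment fix recurs in the NeoX and Santacoder hunks below; all three touch the same per-layer pattern of slicing a padded past tensor. As a minimal sketch of that pattern, assuming a per-layer past whose first dimension is the padded sequence length (the helper name and shapes here are illustrative, not the upstream code):

import torch

def select_layer_past(past_key_values, i, slice_past_index):
    # Assumed layout: past_key_values[i] is [padded_seq_len, 2, num_heads, head_dim]
    if slice_past_index is None:
        # Prefill path: nothing was padded, use the layer past as-is
        return past_key_values[i]
    # Decode path: we added padding that we now need to slice
    return past_key_values[i][:slice_past_index]

# Example: a past padded to 12 positions, of which only the first 5 are real
layer_past = select_layer_past([torch.zeros(12, 2, 8, 64)], 0, slice_past_index=5)
assert layer_past.shape[0] == 5
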
@@ -657,7 +657,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None

@@ -520,7 +520,7 @@ class FlashSantacoderModel(nn.Module):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None

@@ -404,8 +404,7 @@ class FlashCausalLM(Model):
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            # Slice to remove extra padding
-            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
+            # No need to slice as flash attention will take care of it with cu_seqlens
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
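
The replacement comment relies on flash attention's packed-batch interface: all sequences are concatenated into one tensor and cu_seqlens records the cumulative boundaries, so padded rows are never attended to. A minimal sketch of that bookkeeping (everything except the cu_seqlens name is an illustrative assumption):

import torch

# Three sequences of 5, 3 and 7 tokens packed into a single 15-row tensor
input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(input_lengths, 0), (1, 0)).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
# The kernel reads rows cu_seqlens[i]:cu_seqlens[i+1] for sequence i, so any
# padding stored after a sequence falls outside every range and is ignored,
# which is why the batch_size == 1 shortcut no longer needs to slice.
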
@@ -454,13 +453,7 @@
                 )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # Preallocate tensor for bs = 1 case
-                batch.past_key_values = torch.nn.functional.pad(
-                    present,
-                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
-                )
-            else:
+            if len(batch) != 1:
                 # Add padding after each sequence
                 # This will have the correct shape after the final past_key_values concatenation before the model
                 # forward
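
For reference on the deleted branch: torch.nn.functional.pad consumes (left, right) pairs starting from the last dimension, so the pad tuple above grows only the first (sequence) dimension, reserving room for max_new_tokens decode steps in place. A minimal sketch with assumed shapes:

import torch

# Assumed layout: [seq_len, 2 (key/value), num_heads, head_dim]
present = torch.zeros(10, 2, 8, 64)
max_new_tokens = 20
# Pairs apply from the last dim backwards; only dim 0 gets (0, max_new_tokens)
padded = torch.nn.functional.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
assert padded.shape == (30, 2, 8, 64)

After this commit that preallocation no longer happens here for the batch_size == 1 case; only the multi-sequence branch pads, inserting room after each sequence ahead of the final past_key_values concatenation.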