feat(server): reduce memory requirement (#214)
This commit is contained in:
parent
6ded76a4ae
commit
4a7dd4085a
|
@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
generations[1].generated_text.generated_tokens
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
# Copy stopping_criterias before filtering
|
||||
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0]])
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
):
|
||||
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||
|
@ -212,6 +214,15 @@ def test_batch_concatenate(
|
|||
next_batch_1 = default_multi_requests_bloom_batch
|
||||
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
|
||||
|
||||
# Clone past_key_values before concatenating to compare after,
|
||||
# because they are removed from the concatenated batches
|
||||
next_batch_0_past_key_values = [
|
||||
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
|
||||
]
|
||||
next_batch_1_past_key_values = [
|
||||
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
|
||||
]
|
||||
|
||||
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
|
||||
|
@ -246,15 +257,15 @@ def test_batch_concatenate(
|
|||
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:],
|
||||
next_batch_1_past_key_values[i][0][:, :, -1:],
|
||||
past[0][1:, :, :, -1].reshape(-1, 64, 1),
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, -1:, :],
|
||||
next_batch_1_past_key_values[i][1][:, -1:, :],
|
||||
past[1][1:, :, -1, :].reshape(-1, 1, 64),
|
||||
)
|
||||
|
||||
|
|
|
@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
generations[1].generated_text.generated_tokens
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
# Copy stopping_criterias before filtering
|
||||
stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0]])
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
):
|
||||
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
|
@ -209,6 +211,15 @@ def test_batch_concatenate(
|
|||
next_batch_1 = default_multi_requests_causal_lm_batch
|
||||
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
|
||||
|
||||
# Clone past_key_values before concatenating to compare after,
|
||||
# because they are removed from the concatenated batches
|
||||
next_batch_0_past_key_values = [
|
||||
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
|
||||
]
|
||||
next_batch_1_past_key_values = [
|
||||
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
|
||||
]
|
||||
|
||||
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
|
||||
|
@ -244,14 +255,14 @@ def test_batch_concatenate(
|
|||
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
|
||||
next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
|
||||
next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
|
|
|
@ -219,6 +219,19 @@ def test_batch_concatenate(
|
|||
next_batch_1 = default_multi_requests_seq2seq_lm_batch
|
||||
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
|
||||
|
||||
# Copy hidden state because it is removed from the concatenated branches
|
||||
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
|
||||
next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
|
||||
|
||||
# Clone past_key_values before concatenating to compare after,
|
||||
# because they are removed from the concatenated batches
|
||||
next_batch_0_past_key_values = [
|
||||
[t.clone() for t in layer] for layer in next_batch_0.past_key_values
|
||||
]
|
||||
next_batch_1_past_key_values = [
|
||||
[t.clone() for t in layer] for layer in next_batch_1.past_key_values
|
||||
]
|
||||
|
||||
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert next_batch.batch_id == 0
|
||||
|
@ -239,11 +252,11 @@ def test_batch_concatenate(
|
|||
|
||||
assert torch.equal(
|
||||
next_batch.encoder_last_hidden_state[0],
|
||||
next_batch_0.encoder_last_hidden_state[0, -2:],
|
||||
next_batch_0_encoder_last_hidden_state[0, -2:],
|
||||
)
|
||||
assert torch.equal(
|
||||
next_batch.encoder_last_hidden_state[1:],
|
||||
next_batch_1.encoder_last_hidden_state[:, -2:],
|
||||
next_batch_1_encoder_last_hidden_state[:, -2:],
|
||||
)
|
||||
|
||||
assert next_batch.input_lengths == [2, 2, 2]
|
||||
|
@ -275,24 +288,24 @@ def test_batch_concatenate(
|
|||
)
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
|
||||
next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
|
||||
next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
|
||||
next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
|
||||
assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
|
||||
next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
|
|
|
@ -150,6 +150,8 @@ class CausalLMBatch(Batch):
|
|||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
new_padding_right_offset = 0
|
||||
|
||||
for i, r in enumerate(requests):
|
||||
idx = self.requests_idx_mapping[r.id]
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
@ -164,36 +166,57 @@ class CausalLMBatch(Batch):
|
|||
max_input_length = max(max_input_length, request_input_length)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criterias.append(self.stopping_criterias[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
new_padding_right_offset = max(
|
||||
new_padding_right_offset,
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
input_ids = self.input_ids[keep_indices]
|
||||
attention_mask = self.attention_mask[keep_indices]
|
||||
position_ids = self.position_ids[keep_indices]
|
||||
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
|
||||
past_key_values = [
|
||||
[t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
|
||||
for layer in self.past_key_values
|
||||
self.attention_mask = self.attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_input_length):
|
||||
(self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
|
||||
]
|
||||
|
||||
return CausalLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
all_input_ids=all_input_ids,
|
||||
input_lengths=input_lengths,
|
||||
offsets=offsets,
|
||||
token_offsets=token_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length,
|
||||
padding_right_offset=self.padding_right_offset,
|
||||
keys_head_dim_last=self.keys_head_dim_last,
|
||||
)
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if type(self.past_key_values[0]) == tuple:
|
||||
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
||||
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
past_kv_length = max_input_length - 1
|
||||
for layer in self.past_key_values:
|
||||
past_keys, past_values = layer
|
||||
if len(past_keys.shape) == 3:
|
||||
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
|
||||
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
|
||||
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
|
||||
if self.keys_head_dim_last:
|
||||
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
|
||||
else:
|
||||
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
|
||||
del past_keys
|
||||
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
||||
del past_values
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = input_ids
|
||||
self.position_ids = position_ids
|
||||
self.all_input_ids = all_input_ids
|
||||
self.input_lengths = input_lengths
|
||||
self.offsets = offsets
|
||||
self.token_offsets = token_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
|
@ -285,62 +308,88 @@ class CausalLMBatch(Batch):
|
|||
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
|
||||
for j, past in enumerate(batch.past_key_values):
|
||||
past_keys, past_values = past
|
||||
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
|
||||
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
||||
# And ensure that we can update tensors in-place
|
||||
if type(batch.past_key_values[0]) == tuple:
|
||||
batch.past_key_values = [
|
||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
|
||||
]
|
||||
elif batch.past_key_values[0][0].shape == 3:
|
||||
for layer in batch.past_key_values:
|
||||
for k, t in enumerate(layer):
|
||||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||
|
||||
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
|
||||
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
||||
past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
|
||||
past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
|
||||
start_index = end_index
|
||||
|
||||
_, num_heads, padded_sequence_length, head_dim = past_values.shape
|
||||
first_past_kvs = batches[0].past_key_values
|
||||
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
||||
|
||||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_input_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_input_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
if batches[0].keys_head_dim_last:
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
else:
|
||||
# seq_length is last for BLOOM
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_input_length - 1,
|
||||
)
|
||||
|
||||
# Iterate over attention layers
|
||||
# Concatenate past key values layer by layer to allow incremental garbage collection
|
||||
for j in range(len(first_past_kvs)):
|
||||
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_keys = batch.past_key_values[j][0]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][0] = None
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the keys to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
if batch.keys_head_dim_last:
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
padded_past_keys[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = past_keys[:, :, -past_seq_len:, :]
|
||||
else:
|
||||
# seq_length is last for BLOOM
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_input_length - 1,
|
||||
)
|
||||
# BLOOM case
|
||||
padded_past_keys[
|
||||
start_index:end_index, :, :, -past_seq_len:
|
||||
] = past_keys[:, :, :, -past_seq_len:]
|
||||
del past_keys
|
||||
|
||||
# This will run only once per layer
|
||||
if j == len(past_key_values):
|
||||
padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
|
||||
padded_past_values = past_values.new_zeros(padded_past_values_shape)
|
||||
past_key_values.append((padded_past_keys, padded_past_values))
|
||||
start_index = end_index
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
if batch.keys_head_dim_last:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_input_length - 1) :,
|
||||
:,
|
||||
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
|
||||
else:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_input_length - 1) :,
|
||||
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
|
||||
padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_values = batch.past_key_values[j][1]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][1] = None
|
||||
|
||||
past_key_values[j][1][
|
||||
start_index:end_index, :, -(batch.max_input_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the past values to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = past_values[:, :, -past_seq_len:, :]
|
||||
del past_values
|
||||
|
||||
start_index += len(batch)
|
||||
start_index = end_index
|
||||
|
||||
past_key_values.append([padded_past_keys, padded_past_values])
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
|
|
|
@ -25,7 +25,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
requests_idx_mapping: Dict[int, int]
|
||||
|
||||
# Encoder values
|
||||
input_ids: torch.Tensor
|
||||
input_ids: Optional[torch.Tensor]
|
||||
attention_mask: torch.Tensor
|
||||
|
||||
# Decoder values
|
||||
|
@ -164,6 +164,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
|
||||
max_input_length = 0
|
||||
max_decoder_input_length = 0
|
||||
padding_right_offset = 0
|
||||
|
||||
for i, r in enumerate(requests):
|
||||
idx = self.requests_idx_mapping[r.id]
|
||||
|
@ -184,45 +185,53 @@ class Seq2SeqLMBatch(Batch):
|
|||
max_decoder_input_length = max(
|
||||
max_decoder_input_length, request_decoder_input_length
|
||||
)
|
||||
padding_right_offset = max(
|
||||
padding_right_offset,
|
||||
self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
|
||||
)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criterias.append(self.stopping_criterias[idx])
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
decoder_input_ids = self.decoder_input_ids[keep_indices]
|
||||
attention_mask = self.attention_mask[keep_indices]
|
||||
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
|
||||
self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
|
||||
if self.decoder_attention_mask is not None:
|
||||
decoder_attention_mask = self.decoder_attention_mask[keep_indices]
|
||||
else:
|
||||
decoder_attention_mask = None
|
||||
self.decoder_attention_mask = self.decoder_attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_decoder_input_length):
|
||||
(self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
|
||||
]
|
||||
|
||||
encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
|
||||
self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
|
||||
|
||||
past_key_values = [
|
||||
[t[keep_indices] for t in layer] for layer in self.past_key_values
|
||||
]
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if type(self.past_key_values[0]) == tuple:
|
||||
self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
|
||||
|
||||
decoder_past_seq_len = max_decoder_input_length - 1
|
||||
for layer in self.past_key_values:
|
||||
layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
|
||||
layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
|
||||
layer[2] = layer[2][keep_indices, :, -max_input_length:]
|
||||
layer[3] = layer[3][keep_indices, :, -max_input_length:]
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = None
|
||||
self.all_decoder_input_ids = all_decoder_input_ids
|
||||
self.input_lengths = input_lengths
|
||||
self.decoder_input_lengths = decoder_input_lengths
|
||||
self.offsets = offsets
|
||||
self.token_offsets = token_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.max_input_length = max_input_length
|
||||
self.max_decoder_input_length = max_decoder_input_length
|
||||
self.padding_right_offset = padding_right_offset
|
||||
|
||||
return self
|
||||
|
||||
return Seq2SeqLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
all_decoder_input_ids=all_decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_last_hidden_state=encoder_last_hidden_state,
|
||||
past_key_values=past_key_values,
|
||||
input_lengths=input_lengths,
|
||||
decoder_input_lengths=decoder_input_lengths,
|
||||
offsets=offsets,
|
||||
token_offsets=token_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length,
|
||||
max_decoder_input_length=max_decoder_input_length,
|
||||
padding_right_offset=self.padding_right_offset,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
|
@ -350,58 +359,78 @@ class Seq2SeqLMBatch(Batch):
|
|||
encoder_last_hidden_state[
|
||||
start_index:end_index, -batch.max_input_length :, :
|
||||
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
|
||||
batch.encoder_last_hidden_state = None
|
||||
|
||||
# Iterate over attention layers
|
||||
for j, past in enumerate(batch.past_key_values):
|
||||
_, num_heads, _, head_dim = past[0].shape
|
||||
# Ensure that we can update tensors in-place
|
||||
if type(batch.past_key_values[0]) == tuple:
|
||||
batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
|
||||
|
||||
# This will run only once per layer
|
||||
if j == len(past_key_values):
|
||||
past_key_values.append([])
|
||||
start_index = end_index
|
||||
|
||||
# Decoder past
|
||||
for k, t in enumerate(past[:2]):
|
||||
padded_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
(max_decoder_input_length - 1),
|
||||
head_dim,
|
||||
)
|
||||
# Determine shapes for new past kv tensors
|
||||
first_past_kvs = batches[0].past_key_values
|
||||
_, num_heads, _, head_dim = first_past_kvs[0][0].shape
|
||||
|
||||
# Initialize tensors
|
||||
# This will run only once per layer and per past tensor
|
||||
if k == len(past_key_values[j]):
|
||||
past_key_values[j].append(t.new_zeros(padded_t_shape))
|
||||
padded_dec_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
(max_decoder_input_length - 1),
|
||||
head_dim,
|
||||
)
|
||||
|
||||
padded_enc_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_input_length,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# Iterate over attention layers
|
||||
for j in range(len(first_past_kvs)):
|
||||
past_key_values.append([])
|
||||
|
||||
# Decoder past
|
||||
for k in range(0, 2):
|
||||
# Initialize tensors
|
||||
padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
|
||||
past_key_values[j].append(padded_past_values)
|
||||
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
t = batch.past_key_values[j][k]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][k] = None
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
past_key_values[j][k][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_decoder_input_length - 1) :,
|
||||
:,
|
||||
] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
|
||||
past_seq_len = batch.max_decoder_input_length - 1
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = t[:, :, -past_seq_len:, :]
|
||||
del t
|
||||
|
||||
# encoder past
|
||||
for k, t in enumerate(past[2:]):
|
||||
padded_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_input_length,
|
||||
head_dim,
|
||||
)
|
||||
start_index = end_index
|
||||
|
||||
idx = k + 2
|
||||
# Encoder past
|
||||
for k in range(2, 4):
|
||||
# Initialize tensors
|
||||
padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
|
||||
past_key_values[j].append(padded_past_values)
|
||||
|
||||
# Initialize tensors
|
||||
# This will run only once per layer and per past tensor
|
||||
if idx == len(past_key_values[j]):
|
||||
past_key_values[j].append(t.new_zeros(padded_t_shape))
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
t = batch.past_key_values[j][k]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][k] = None
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -batch.max_input_length:, :
|
||||
] = t[:, :, -batch.max_input_length:, :]
|
||||
del t
|
||||
|
||||
past_key_values[j][idx][
|
||||
start_index:end_index, :, -batch.max_input_length :, :
|
||||
] = t[:, :, -batch.max_input_length :, :]
|
||||
|
||||
start_index += len(batch)
|
||||
start_index = end_index
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
|
|
Loading…
Reference in New Issue