diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index de0ef57b..47d701eb 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens
+        - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
@@ -212,6 +214,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_bloom_batch
     _, next_batch_1 = default_bloom.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -246,15 +257,15 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:],
+            next_batch_1_past_key_values[i][0][:, :, -1:],
             past[0][1:, :, :, -1].reshape(-1, 64, 1),
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, -1:, :],
+            next_batch_1_past_key_values[i][1][:, -1:, :],
             past[1][1:, :, -1, :].reshape(-1, 1, 64),
         )
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index ad79a4ca..03d3ef9b 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens
+        - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
@@ -209,6 +211,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_causal_lm_batch
     _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -244,14 +255,14 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
         )

     for _ in range(
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 79c9e936..65dafa50 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -219,6 +219,19 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_seq2seq_lm_batch
     _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)

+    # Copy hidden state because it is removed from the concatenated batches
+    next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
+    next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
+
+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_1.past_key_values
+    ]
+
     next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])

     assert next_batch.batch_id == 0
@@ -239,11 +252,11 @@ def test_batch_concatenate(

     assert torch.equal(
         next_batch.encoder_last_hidden_state[0],
-        next_batch_0.encoder_last_hidden_state[0, -2:],
+        next_batch_0_encoder_last_hidden_state[0, -2:],
     )
     assert torch.equal(
         next_batch.encoder_last_hidden_state[1:],
-        next_batch_1.encoder_last_hidden_state[:, -2:],
+        next_batch_1_encoder_last_hidden_state[:, -2:],
     )

     assert next_batch.input_lengths == [2, 2, 2]
@@ -275,24 +288,24 @@ def test_batch_concatenate(
     )

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
+        assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
+            next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
+        assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
+            next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
         )

     for _ in range(3):
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 98313253..1db5abce 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -150,6 +150,8 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

+        new_padding_right_offset = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i
@@ -164,36 +166,57 @@ class CausalLMBatch(Batch):
             max_input_length = max(max_input_length, request_input_length)

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+
+            new_padding_right_offset = max(
+                new_padding_right_offset,
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
         position_ids = self.position_ids[keep_indices]
-        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        past_key_values = [
-            [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
-            for layer in self.past_key_values
+        self.attention_mask = self.attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + max_input_length):
+            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
         ]

-        return CausalLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            padding_right_offset=self.padding_right_offset,
-            keys_head_dim_last=self.keys_head_dim_last,
-        )
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+        # Update tensors in-place to allow incremental garbage collection
+        past_kv_length = max_input_length - 1
+        for layer in self.past_key_values:
+            past_keys, past_values = layer
+            if len(past_keys.shape) == 3:
+                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+            if self.keys_head_dim_last:
+                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+            else:
+                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+            del past_keys
+            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+            del past_values
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.position_ids = position_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+
+        return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -285,62 +308,88 @@ class CausalLMBatch(Batch):
                 position_ids = batch.position_ids.new_empty((total_batch_size, 1))
             position_ids[start_index:end_index] = batch.position_ids

-            for j, past in enumerate(batch.past_key_values):
-                past_keys, past_values = past
+            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+            # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
+            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+            # And ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
+                ]
+            elif len(batch.past_key_values[0][0].shape) == 3:
+                for layer in batch.past_key_values:
+                    for k, t in enumerate(layer):
+                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

-                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
-                # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
-                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
-                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
+            start_index = end_index

-                _, num_heads, padded_sequence_length, head_dim = past_values.shape
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

-                padded_past_values_shape = (
-                    total_batch_size,
-                    num_heads,
-                    max_input_length - 1,
-                    head_dim,
-                )
+        padded_past_values_shape = (
+            total_batch_size,
+            num_heads,
+            max_input_length - 1,
+            head_dim,
+        )
+        if batches[0].keys_head_dim_last:
+            padded_past_keys_shape = padded_past_values_shape
+        else:
+            # seq_length is last for BLOOM
+            padded_past_keys_shape = (
+                total_batch_size,
+                num_heads,
+                head_dim,
+                max_input_length - 1,
+            )
+
+        # Iterate over attention layers
+        # Concatenate past key values layer by layer to allow incremental garbage collection
+        for j in range(len(first_past_kvs)):
+            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+            start_index = 0
+            for batch in batches:
+                past_keys = batch.past_key_values[j][0]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][0] = None
+
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the keys to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys_shape = padded_past_values_shape
+                    padded_past_keys[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = past_keys[:, :, -past_seq_len:, :]
                 else:
-                    # seq_length is last for BLOOM
-                    padded_past_keys_shape = (
-                        total_batch_size,
-                        num_heads,
-                        head_dim,
-                        max_input_length - 1,
-                    )
+                    # BLOOM case
+                    padded_past_keys[
+                        start_index:end_index, :, :, -past_seq_len:
+                    ] = past_keys[:, :, :, -past_seq_len:]
+                del past_keys

-                # This will run only once per layer
-                if j == len(past_key_values):
-                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
-                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
-                    past_key_values.append((padded_past_keys, padded_past_values))
+                start_index = end_index

-                # We slice the past keys and values to remove the padding from previous batches
-                if batch.keys_head_dim_last:
-                    past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                        :,
-                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
-                else:
-                    past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
+            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            start_index = 0
+            for batch in batches:
+                past_values = batch.past_key_values[j][1]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][1] = None

-                past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the past values to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
+                padded_past_values[
+                    start_index:end_index, :, -past_seq_len:, :
+                ] = past_values[:, :, -past_seq_len:, :]
+                del past_values

-                start_index += len(batch)
+                start_index = end_index
+
+            past_key_values.append([padded_past_keys, padded_past_values])

         return cls(
             batch_id=batches[0].batch_id,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index aa452c70..2252fcfc 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -25,7 +25,7 @@ class Seq2SeqLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Encoder values
-    input_ids: torch.Tensor
+    input_ids: Optional[torch.Tensor]
     attention_mask: torch.Tensor

     # Decoder values
@@ -164,6 +164,7 @@ class Seq2SeqLMBatch(Batch):

         max_input_length = 0
         max_decoder_input_length = 0
+        padding_right_offset = 0

         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -184,45 +185,53 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, request_decoder_input_length
             )
+            padding_right_offset = max(
+                padding_right_offset,
+                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+            )

             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criterias.append(self.stopping_criterias[idx])

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        decoder_input_ids = self.decoder_input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
+        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
+        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
-            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
-        else:
-            decoder_attention_mask = None
+            self.decoder_attention_mask = self.decoder_attention_mask[
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length):
+                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+            ]

-        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]

-        past_key_values = [
-            [t[keep_indices] for t in layer] for layer in self.past_key_values
-        ]
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+
+        decoder_past_seq_len = max_decoder_input_length - 1
+        for layer in self.past_key_values:
+            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
+            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
+            layer[2] = layer[2][keep_indices, :, -max_input_length:]
+            layer[3] = layer[3][keep_indices, :, -max_input_length:]
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = None
+        self.all_decoder_input_ids = all_decoder_input_ids
+        self.input_lengths = input_lengths
+        self.decoder_input_lengths = decoder_input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.max_decoder_input_length = max_decoder_input_length
+        self.padding_right_offset = padding_right_offset
+
+        return self

-        return Seq2SeqLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=None,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            all_decoder_input_ids=all_decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            encoder_last_hidden_state=encoder_last_hidden_state,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            max_decoder_input_length=max_decoder_input_length,
-            padding_right_offset=self.padding_right_offset,
-        )

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -350,58 +359,78 @@ class Seq2SeqLMBatch(Batch):
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
             ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+            batch.encoder_last_hidden_state = None

-            # Iterate over attention layers
-            for j, past in enumerate(batch.past_key_values):
-                _, num_heads, _, head_dim = past[0].shape
+            # Ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]

-                # This will run only once per layer
-                if j == len(past_key_values):
-                    past_key_values.append([])
+            start_index = end_index

-                # Decoder past
-                for k, t in enumerate(past[:2]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        (max_decoder_input_length - 1),
-                        head_dim,
-                    )
+        # Determine shapes for new past kv tensors
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, _, head_dim = first_past_kvs[0][0].shape

-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if k == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
+        padded_dec_t_shape = (
+            total_batch_size,
+            num_heads,
+            (max_decoder_input_length - 1),
+            head_dim,
+        )
+        padded_enc_t_shape = (
+            total_batch_size,
+            num_heads,
+            max_input_length,
+            head_dim,
+        )
+
+        # Iterate over attention layers
+        for j in range(len(first_past_kvs)):
+            past_key_values.append([])
+
+            # Decoder past
+            for k in range(0, 2):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
+                past_key_values[j].append(padded_past_values)
+
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
                     # We slice the past keys and values to remove the padding from previous batches
-                    past_key_values[j][k][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_decoder_input_length - 1) :,
-                        :,
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
+                    past_seq_len = batch.max_decoder_input_length - 1
+                    padded_past_values[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = t[:, :, -past_seq_len:, :]
+                    del t

-                # encoder past
-                for k, t in enumerate(past[2:]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        max_input_length,
-                        head_dim,
-                    )
+                    start_index = end_index

-                    idx = k + 2
+            # Encoder past
+            for k in range(2, 4):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
+                past_key_values[j].append(padded_past_values)

-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if idx == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past keys and values to remove the padding from previous batches
+                    padded_past_values[
+                        start_index:end_index, :, -batch.max_input_length:, :
+                    ] = t[:, :, -batch.max_input_length:, :]
+                    del t

-                    past_key_values[j][idx][
-                        start_index:end_index, :, -batch.max_input_length :, :
-                    ] = t[:, :, -batch.max_input_length :, :]
-
-                    start_index += len(batch)
+                    start_index = end_index

         return cls(
             batch_id=batches[0].batch_id,
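
Note: the snippet below is a minimal, self-contained sketch of the in-place trimming pattern that the new `filter` implementations rely on; it is an illustration only, and the function name `trim_past_in_place` and its arguments are hypothetical, not part of the repository. The idea is to slice each cached tensor down to the kept requests and positions, reassign it, and drop references to the old tensors so their larger storage can be reclaimed before the next allocation.

import torch


def trim_past_in_place(past_key_values, keep_indices, past_kv_length):
    # `past_key_values` is a list of [keys, values] lists, one per layer,
    # each shaped [batch_size, num_heads, seq_length, head_dim]
    for layer in past_key_values:
        past_keys, past_values = layer
        # Keep only the selected sequences and the last `past_kv_length` positions.
        # Advanced indexing copies, so the new tensors are smaller allocations.
        layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
        layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
        # Dropping these names removes the remaining references to the original,
        # larger tensors, so their memory can be reclaimed before the next layer
        # is processed.
        del past_keys, past_values
    return past_key_values


if __name__ == "__main__":
    batch_size, num_heads, seq_length, head_dim = 4, 2, 8, 16
    past = [
        [
            torch.randn(batch_size, num_heads, seq_length, head_dim),
            torch.randn(batch_size, num_heads, seq_length, head_dim),
        ]
    ]
    past = trim_past_in_place(past, keep_indices=[0, 2], past_kv_length=5)
    print(past[0][0].shape)  # torch.Size([2, 2, 5, 16])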
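
Note: along the same lines, here is a hypothetical sketch of the layer-by-layer concatenation strategy used above: allocate one padded tensor per layer, copy each batch's cache into its slice, and clear the source reference immediately so that at most one layer's worth of duplicated cache is alive at any time. `concatenate_layer_keys` and its argument layout are assumptions made for illustration, not the library's API.

import torch


def concatenate_layer_keys(batches_past, j, max_input_length, total_batch_size):
    # `batches_past` is a list of (past_key_values, batch_size, batch_max_input_length),
    # with keys shaped [batch_size, num_heads, padded_seq_length, head_dim]
    first_keys = batches_past[0][0][j][0]
    num_heads, head_dim = first_keys.shape[1], first_keys.shape[-1]
    padded = first_keys.new_zeros(
        (total_batch_size, num_heads, max_input_length - 1, head_dim)
    )

    start = 0
    for past_key_values, batch_size, batch_max_input_length in batches_past:
        keys = past_key_values[j][0]
        # Drop the source reference as soon as it has been read
        past_key_values[j][0] = None
        end = start + batch_size
        # Skip the left padding carried over from the source batch
        seq_len = batch_max_input_length - 1
        padded[start:end, :, -seq_len:, :] = keys[:, :, -seq_len:, :]
        del keys
        start = end
    return padded


if __name__ == "__main__":
    # Two toy "batches" with different padded sequence lengths
    a = [[torch.randn(2, 4, 5, 8), torch.randn(2, 4, 5, 8)]]
    b = [[torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)]]
    merged_keys = concatenate_layer_keys(
        [(a, 2, 6), (b, 1, 4)], j=0, max_input_length=6, total_batch_size=3
    )
    print(merged_keys.shape)  # torch.Size([3, 4, 5, 8])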