feat(server): pre-allocate max attention mask (#75)

This commit is contained in:
OlivierDehaene 2023-02-24 12:49:21 +01:00 committed by GitHub
parent 78063c0569
commit 44ce098c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 148 additions and 114 deletions

View File

@ -65,8 +65,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.input_ids[0][-1] == 10264 assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3) assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][-1] == 1 assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0) assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None assert batch.past_key_values is None
@ -98,16 +98,13 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert not next_batch.keys_head_dim_last assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size assert len(next_batch.all_input_ids) == next_batch.size
assert ( assert len(next_batch.all_input_ids[0]) == sequence_length + 1
len(next_batch.all_input_ids[0]) assert len(next_batch.attention_mask[0]) == 11
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264) assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3) assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][-2:] == 1) assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0) assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1) assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 10264 assert next_batch.input_ids[0, 0] == 10264
@ -213,9 +210,13 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1) assert torch.all(
assert torch.all(next_batch.attention_mask[1:, -2:] == 1) next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
assert torch.all(next_batch.attention_mask[1:, :-2] == 0) )
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264) assert torch.all(next_batch.input_ids == 10264)

View File

@ -62,8 +62,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.input_ids[0][-1] == 14402 assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256) assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0][-1] == 1 assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0) assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None assert batch.past_key_values is None
@ -94,17 +94,14 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert isinstance(next_batch, CausalLMBatch) assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size assert len(next_batch.all_input_ids) == next_batch.size
assert ( assert len(next_batch.all_input_ids[0]) == sequence_length + 1
len(next_batch.all_input_ids[0]) assert len(next_batch.attention_mask[0]) == 11
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert next_batch.all_input_ids[0][-1] == 13 assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402 assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][-2:] == 1) assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0) assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1) assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 13 assert next_batch.input_ids[0, 0] == 13
@ -210,9 +207,13 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1) assert torch.all(
assert torch.all(next_batch.attention_mask[1:, -2:] == 1) next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
assert torch.all(next_batch.attention_mask[1:, :-2] == 0) )
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355 assert next_batch.input_ids[0, 0] == 12355

View File

@ -106,7 +106,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert len(generations) == len(next_batch) assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch) assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids) assert next_batch.input_ids is None
assert torch.equal( assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
) )
@ -220,11 +220,6 @@ def test_batch_concatenate(
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids[:, 0] == 4268)
assert torch.all(next_batch.input_ids[:, 1] == 1)
assert torch.all(next_batch.attention_mask == 1)
assert torch.equal( assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0] next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
) )
@ -233,9 +228,10 @@ def test_batch_concatenate(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
) )
assert torch.all(next_batch.decoder_attention_mask[0] == 1) assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0) assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1) assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal( assert torch.equal(
next_batch.encoder_last_hidden_state[0], next_batch.encoder_last_hidden_state[0],

View File

@ -37,6 +37,7 @@ class CausalLMBatch(Batch):
# Metadata used for padding # Metadata used for padding
size: int size: int
max_sequence_length: int max_sequence_length: int
padding_right_offset: int
# Past metadata # Past metadata
keys_head_dim_last: bool = True keys_head_dim_last: bool = True
@ -61,22 +62,36 @@ class CausalLMBatch(Batch):
input_lengths = [] input_lengths = []
# Parse batch # Parse batch
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests: for r in pb.requests:
inputs.append(r.inputs) inputs.append(r.inputs)
input_lengths.append(r.input_length) input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
) )
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -84,8 +99,8 @@ class CausalLMBatch(Batch):
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
input_ids=tokenized_inputs["input_ids"], input_ids=input_ids,
attention_mask=tokenized_inputs["attention_mask"], attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
@ -93,15 +108,21 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=pb.size, size=pb.size,
max_sequence_length=max(input_lengths), max_sequence_length=max_sequence_length,
padding_right_offset=padding_right_offset,
) )
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding # Used for padding
total_batch_size = sum(batch.size for batch in batches) total_batch_size = 0
max_sequence_length = max(batch.max_sequence_length for batch in batches) max_sequence_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
requests = [] requests = []
@ -144,13 +165,22 @@ class CausalLMBatch(Batch):
# Create padded tensor # Create padded tensor
if attention_mask is None: if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros( attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_sequence_length), (total_batch_size, max_sequence_length + padding_right_offset),
) )
# We need to slice the attention mask to remove padding from previous steps # We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_sequence_length - batch.max_sequence_length
batch_left_offset = (
batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
)
attention_mask[ attention_mask[
start_index:end_index, -batch.max_sequence_length : start_index:end_index,
] = batch.attention_mask[:, -batch.max_sequence_length :] left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor # Create empty tensor
# position_ids is always of shape [batch_size, 1] # position_ids is always of shape [batch_size, 1]
@ -228,6 +258,7 @@ class CausalLMBatch(Batch):
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=total_batch_size, size=total_batch_size,
max_sequence_length=max_sequence_length, max_sequence_length=max_sequence_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
) )
@ -294,9 +325,12 @@ class CausalLM(Model):
def generate_token( def generate_token(
self, batch: CausalLMBatch self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]: ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, past = self.forward( logits, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, attention_mask,
batch.position_ids, batch.position_ids,
batch.past_key_values, batch.past_key_values,
) )
@ -448,14 +482,8 @@ class CausalLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids # Update attention_mask as we added a new token to input_ids
next_batch_attention_mask = torch.cat( next_batch_attention_mask[:, -batch.padding_right_offset] = 1
[
next_batch_attention_mask,
next_batch_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
# Update position_ids # Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1 next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@ -473,6 +501,7 @@ class CausalLM(Model):
stopping_criterias=next_batch_stopping_criterias, stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size, size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length, max_sequence_length=next_batch_max_sequence_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last, keys_head_dim_last=batch.keys_head_dim_last,
) )
return generations, next_batch return generations, next_batch

View File

@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
) )
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1

View File

@ -42,6 +42,7 @@ class Seq2SeqLMBatch(Batch):
size: int size: int
max_input_length: int max_input_length: int
max_decoder_input_length: int max_decoder_input_length: int
padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch: def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf""" """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
@ -68,6 +69,8 @@ class Seq2SeqLMBatch(Batch):
decoder_input_lengths = [] decoder_input_lengths = []
# Parse batch # Parse batch
max_input_length = 0
padding_right_offset = 0
for r in pb.requests: for r in pb.requests:
inputs.append(r.inputs) inputs.append(r.inputs)
input_lengths.append(r.input_length) input_lengths.append(r.input_length)
@ -75,17 +78,20 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids.append(tokenizer.bos_token_id) decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_input_length = max(max_input_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
) )
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
# Convert decoder_input_ids to torch tensor of size [batch_size, 1] # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@ -107,6 +113,7 @@ class Seq2SeqLMBatch(Batch):
size=len(pb.requests), size=len(pb.requests),
max_input_length=max(input_lengths), max_input_length=max(input_lengths),
max_decoder_input_length=1, max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
) )
@classmethod @classmethod
@ -115,11 +122,17 @@ class Seq2SeqLMBatch(Batch):
"""Concatenate multiple batches together by padding internal torch tensors""" """Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding # Used for padding
total_batch_size = sum(batch.size for batch in batches) total_batch_size = 0
max_input_length = max(batch.max_input_length for batch in batches) max_input_length = 0
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max( max_decoder_input_length = max(
batch.max_decoder_input_length for batch in batches max_decoder_input_length, batch.max_decoder_input_length
) )
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
requests = [] requests = []
@ -129,7 +142,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
# Batch tensors # Batch tensors
input_ids = None
attention_mask = None attention_mask = None
decoder_input_ids = None decoder_input_ids = None
decoder_attention_mask = None decoder_attention_mask = None
@ -155,16 +167,6 @@ class Seq2SeqLMBatch(Batch):
if batch.encoder_last_hidden_state is None: if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None") raise ValueError("Batch encoder_last_hidden_state cannot be None")
# Create padded tensor
if input_ids is None:
input_ids = batch.input_ids.new_zeros(
(total_batch_size, max_input_length),
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
# Create padded tensor # Create padded tensor
if attention_mask is None: if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros( attention_mask = batch.attention_mask.new_zeros(
@ -189,19 +191,29 @@ class Seq2SeqLMBatch(Batch):
if decoder_attention_mask is None: if decoder_attention_mask is None:
# As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
decoder_attention_mask = batch.attention_mask.new_zeros( decoder_attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_decoder_input_length), (total_batch_size, max_decoder_input_length + padding_right_offset),
) )
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`. # this batch. All generations are of length `batch.max_decoder_input_length`.
left_offset = max_decoder_input_length - batch.max_decoder_input_length
if batch.decoder_attention_mask is None: if batch.decoder_attention_mask is None:
decoder_attention_mask[ decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length : start_index:end_index,
left_offset:-padding_right_offset,
] = 1 ] = 1
# If it exists, we need to index # If it exists, we need to index
else: else:
batch_left_offset = (
batch.decoder_attention_mask.shape[1]
- batch.max_decoder_input_length - batch.padding_right_offset
)
decoder_attention_mask[ decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length : start_index:end_index,
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :] left_offset:-padding_right_offset,
] = batch.decoder_attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create padded tensor # Create padded tensor
if encoder_last_hidden_state is None: if encoder_last_hidden_state is None:
@ -273,7 +285,7 @@ class Seq2SeqLMBatch(Batch):
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
input_ids=input_ids, input_ids=None,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
@ -286,6 +298,7 @@ class Seq2SeqLMBatch(Batch):
size=total_batch_size, size=total_batch_size,
max_input_length=max_input_length, max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length, max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
) )
def __len__(self): def __len__(self):
@ -342,14 +355,6 @@ class Seq2SeqLM(Model):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@ -369,12 +374,34 @@ class Seq2SeqLM(Model):
def generate_token( def generate_token(
self, batch: Seq2SeqLMBatch self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
if batch.decoder_attention_mask is not None:
# slice to the correct shape
decoder_attention_mask = batch.decoder_attention_mask[
:, : -batch.padding_right_offset
]
else:
decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else:
encoder_last_hidden_state = batch.encoder_last_hidden_state
logits, encoder_last_hidden_state, past = self.forward( logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
batch.decoder_input_ids, decoder_input_ids,
batch.decoder_attention_mask, decoder_attention_mask,
batch.encoder_last_hidden_state, encoder_last_hidden_state,
batch.past_key_values, batch.past_key_values,
) )
@ -402,7 +429,6 @@ class Seq2SeqLM(Model):
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.input_ids,
batch.decoder_input_ids, batch.decoder_input_ids,
) )
@ -414,7 +440,6 @@ class Seq2SeqLM(Model):
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
input_tokens,
decoder_input_ids, decoder_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
@ -500,10 +525,8 @@ class Seq2SeqLM(Model):
# If we finished at least one generation, we need to evict the indices of the generations that finished # If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch # from the values of the next batch
if len(next_batch_keep_indices) != len(batch): if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached # Apply indices to decoder_attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None: if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[ next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices next_batch_keep_indices
@ -526,7 +549,6 @@ class Seq2SeqLM(Model):
batch.stopping_criterias[i] for i in next_batch_keep_indices batch.stopping_criterias[i] for i in next_batch_keep_indices
] ]
else: else:
next_batch_input_ids = batch.input_ids
next_batch_attention_mask = batch.attention_mask next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state next_batch_encoder_last_hidden_state = encoder_last_hidden_state
@ -536,20 +558,14 @@ class Seq2SeqLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias next_batch_stopping_criterias = batch.stopping_criterias
# Update decoder_attention_mask with padding as we added a new token to input_ids # Update decoder_attention_mask as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None: if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat( next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
[
next_batch_decoder_attention_mask,
next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
next_batch = Seq2SeqLMBatch( next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id, batch_id=batch.batch_id,
requests=next_batch_requests, requests=next_batch_requests,
input_ids=next_batch_input_ids, input_ids=None,
attention_mask=next_batch_attention_mask, attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids, decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask, decoder_attention_mask=next_batch_decoder_attention_mask,
@ -562,5 +578,6 @@ class Seq2SeqLM(Model):
size=next_batch_size, size=next_batch_size,
max_input_length=next_batch_max_input_length, max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length, max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
) )
return generations, next_batch return generations, next_batch

View File

@ -221,14 +221,6 @@ class T5Sharded(Seq2SeqLM):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,