fix(server): Fix position ids (#28)

This commit is contained in:
OlivierDehaene 2023-01-20 15:35:22 +01:00 committed by GitHub
parent 15511edc01
commit 1f570d181f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 36 deletions

View File

@ -42,6 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1) return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch): def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb( batch = CausalLMBatch.from_pb(
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
@ -65,6 +66,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
) )
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion( def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch default_santacoder, default_fim_pb_batch
): ):

View File

@ -236,10 +236,11 @@ class BLOOMSharded(BLOOM):
if name == "word_embeddings.weight": if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None): def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )

View File

@ -18,6 +18,7 @@ class CausalLMBatch(Batch):
# Decoder values # Decoder values
input_ids: torch.Tensor input_ids: torch.Tensor
attention_mask: torch.Tensor attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Tuple]] past_key_values: Optional[List[Tuple]]
# All tokens # All tokens
@ -76,6 +77,8 @@ class CausalLMBatch(Batch):
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls( return cls(
@ -83,6 +86,7 @@ class CausalLMBatch(Batch):
requests=pb.requests, requests=pb.requests,
input_ids=tokenized_inputs["input_ids"], input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"], attention_mask=tokenized_inputs["attention_mask"],
position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_logprobs=all_logprobs, all_logprobs=all_logprobs,
@ -110,6 +114,7 @@ class CausalLMBatch(Batch):
# Batch tensors # Batch tensors
input_ids = None input_ids = None
attention_mask = None attention_mask = None
position_ids = None
past_key_values = [] past_key_values = []
# Used for slicing correctly inside the tensors # Used for slicing correctly inside the tensors
@ -149,6 +154,12 @@ class CausalLMBatch(Batch):
start_index:end_index, -batch.max_sequence_length : start_index:end_index, -batch.max_sequence_length :
] = batch.attention_mask[:, -batch.max_sequence_length :] ] = batch.attention_mask[:, -batch.max_sequence_length :]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
for j, past in enumerate(batch.past_key_values): for j, past in enumerate(batch.past_key_values):
past_keys, past_values = past past_keys, past_values = past
@ -211,6 +222,7 @@ class CausalLMBatch(Batch):
requests=requests, requests=requests,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_logprobs=all_logprobs, all_logprobs=all_logprobs,
@ -263,12 +275,13 @@ class CausalLM(Model):
) )
def forward( def forward(
self, input_ids, attention_mask, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward # Model Forward
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
@ -283,7 +296,7 @@ class CausalLM(Model):
) )
with context_manager(): with context_manager():
logits, past = self.forward( logits, past = self.forward(
batch.input_ids, batch.attention_mask, batch.past_key_values batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
) )
# List of indices to cache # List of indices to cache
@ -356,7 +369,7 @@ class CausalLM(Model):
token_ids = all_input_ids[-new_input_length:] token_ids = all_input_ids[-new_input_length:]
tokens = self.tokenizer.batch_decode(token_ids) tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the first prompt token # Add NaN for the first prompt token
logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze( logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(
1 1
).tolist() ).tolist()
@ -394,6 +407,7 @@ class CausalLM(Model):
if generated_texts: if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached # Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [ next_batch_past_key_values = [
[ [
@ -411,6 +425,7 @@ class CausalLM(Model):
] ]
else: else:
next_batch_attention_mask = batch.attention_mask next_batch_attention_mask = batch.attention_mask
next_batch_position_ids = batch.position_ids
next_batch_past_key_values = past next_batch_past_key_values = past
next_batch_requests = batch.requests next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers next_batch_next_token_choosers = batch.next_token_choosers
@ -425,11 +440,15 @@ class CausalLM(Model):
dim=1, dim=1,
) )
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
next_batch = CausalLMBatch( next_batch = CausalLMBatch(
batch_id=batch.batch_id, batch_id=batch.batch_id,
requests=next_batch_requests, requests=next_batch_requests,
input_ids=next_batch_input_ids, input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask, attention_mask=next_batch_attention_mask,
position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values, past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids, all_input_ids=next_batch_all_input_ids,
all_logprobs=next_batch_all_logprobs, all_logprobs=next_batch_all_logprobs,

View File

@ -116,6 +116,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls( return cls(
@ -123,6 +125,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests=pb.requests, requests=pb.requests,
input_ids=tokenized_inputs["input_ids"], input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"], attention_mask=tokenized_inputs["attention_mask"],
position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
input_lengths=input_lengths, input_lengths=input_lengths,
@ -330,10 +333,11 @@ class GalacticaSharded(Galactica):
if name == "model.decoder.embed_tokens.weight": if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None): def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )

View File

@ -42,10 +42,9 @@ class SantaCoder(CausalLM):
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, model_name,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize, load_in_8bit=quantize,
trust_remote_code=True, # required trust_remote_code=True, # required
).eval() ).to(device).eval()
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -57,31 +56,3 @@ class SantaCoder(CausalLM):
return self.tokenizer.decode( return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
) )
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
# the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length D]
# this leads to position_ids being wrong
input_length = input_ids.shape[-1]
past_key_values_length = (
0 if past_key_values is None else past_key_values[0][0].shape[-1]
)
position_ids = torch.arange(
past_key_values_length,
input_length + past_key_values_length,
dtype=torch.long,
device=input_ids.device,
).view(1, input_length)
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
use_cache=True,
)
return outputs.logits, outputs.past_key_values

View File

@ -449,7 +449,7 @@ class Seq2SeqLM(Model):
tokens = self.tokenizer.batch_decode(token_ids) tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token # Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[ logprobs = [float("nan")] + decoder_logprobs[
-new_decoder_input_length: -decoder_input_length:
].tolist() ].tolist()
# Add to the list of finished generations with the original request # Add to the list of finished generations with the original request
generated_texts.append( generated_texts.append(