fix(server): Fix position ids (#28)

This commit is contained in:
OlivierDehaene 2023-01-20 15:35:22 +01:00 committed by GitHub
parent 15511edc01
commit 1f570d181f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 36 deletions

View File

@ -42,6 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
@ -65,6 +66,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
)
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):

View File

@ -236,10 +236,11 @@ class BLOOMSharded(BLOOM):
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)

View File

@ -18,6 +18,7 @@ class CausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Tuple]]
# All tokens
@ -76,6 +77,8 @@ class CausalLMBatch(Batch):
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls(
@ -83,6 +86,7 @@ class CausalLMBatch(Batch):
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
@ -110,6 +114,7 @@ class CausalLMBatch(Batch):
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []
# Used for slicing correctly inside the tensors
@ -149,6 +154,12 @@ class CausalLMBatch(Batch):
start_index:end_index, -batch.max_sequence_length :
] = batch.attention_mask[:, -batch.max_sequence_length :]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
for j, past in enumerate(batch.past_key_values):
past_keys, past_values = past
@ -211,6 +222,7 @@ class CausalLMBatch(Batch):
requests=requests,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
@ -263,12 +275,13 @@ class CausalLM(Model):
)
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
@ -283,7 +296,7 @@ class CausalLM(Model):
)
with context_manager():
logits, past = self.forward(
batch.input_ids, batch.attention_mask, batch.past_key_values
batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
)
# List of indices to cache
@ -356,7 +369,7 @@ class CausalLM(Model):
token_ids = all_input_ids[-new_input_length:]
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the first prompt token
logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(
1
).tolist()
@ -394,6 +407,7 @@ class CausalLM(Model):
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [
[
@ -411,6 +425,7 @@ class CausalLM(Model):
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_position_ids = batch.position_ids
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
@ -425,11 +440,15 @@ class CausalLM(Model):
dim=1,
)
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
all_logprobs=next_batch_all_logprobs,

View File

@ -116,6 +116,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls(
@ -123,6 +125,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
@ -330,10 +333,11 @@ class GalacticaSharded(Galactica):
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)

View File

@ -42,10 +42,9 @@ class SantaCoder(CausalLM):
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
trust_remote_code=True, # required
).eval()
).to(device).eval()
super(CausalLM, self).__init__(
tokenizer=tokenizer,
@ -57,31 +56,3 @@ class SantaCoder(CausalLM):
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
# the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length D]
# this leads to position_ids being wrong
input_length = input_ids.shape[-1]
past_key_values_length = (
0 if past_key_values is None else past_key_values[0][0].shape[-1]
)
position_ids = torch.arange(
past_key_values_length,
input_length + past_key_values_length,
dtype=torch.long,
device=input_ids.device,
).view(1, input_length)
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
use_cache=True,
)
return outputs.logits, outputs.past_key_values

View File

@ -449,7 +449,7 @@ class Seq2SeqLM(Model):
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[
-new_decoder_input_length:
-decoder_input_length:
].tolist()
# Add to the list of finished generations with the original request
generated_texts.append(