Fixing top_n_tokens. (#1497)
# What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
c2d4a3b5c7
commit
069895b985
|
@ -50,19 +50,39 @@ def test_batch_top_tokens():
|
|||
top_n_tokens = [0, 2, 3, 4, 5]
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens)
|
||||
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
|
||||
accepted_ids = torch.ones_like(top_n_tokens_tensor)
|
||||
|
||||
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
|
||||
top_n_tokens, top_n_tokens_tensor, inp_logprobs
|
||||
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
|
||||
)
|
||||
|
||||
assert topn_tok_ids[0] == []
|
||||
assert topn_tok_ids[1] == [0, 3]
|
||||
assert topn_tok_ids[2] == [0, 3, 1, 4]
|
||||
assert topn_tok_ids[3] == [0, 3, 1, 4]
|
||||
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
|
||||
assert topn_tok_ids[0] == [[]]
|
||||
assert topn_tok_ids[1] == [[0, 3]]
|
||||
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
|
||||
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
|
||||
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
|
||||
|
||||
assert topn_tok_logprobs[0] == []
|
||||
assert topn_tok_logprobs[1] == [-1, -2]
|
||||
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
|
||||
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
|
||||
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
|
||||
assert topn_tok_logprobs[0] == [[]]
|
||||
assert topn_tok_logprobs[1] == [[-1, -2]]
|
||||
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
|
||||
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
|
||||
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
|
||||
|
||||
# Now let's make second member of the batch be speculated
|
||||
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
|
||||
accepted_ids[1] = 2
|
||||
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
|
||||
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
|
||||
)
|
||||
|
||||
assert topn_tok_ids[0] == [[]]
|
||||
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
|
||||
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
|
||||
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
|
||||
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
|
||||
|
||||
assert topn_tok_logprobs[0] == [[]]
|
||||
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
|
||||
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
|
||||
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
|
||||
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
|
||||
|
|
|
@ -580,10 +580,13 @@ class CausalLM(Model):
|
|||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
# Speculation is not active for causal
|
||||
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
accepted_ids,
|
||||
)
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
@ -692,20 +695,24 @@ class CausalLM(Model):
|
|||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens.append(top_tokens)
|
||||
top_tokens = all_top_tokens
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
|
|
|
@ -842,6 +842,8 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
next_token_logits = out
|
||||
|
||||
|
||||
speculate = get_speculate()
|
||||
(
|
||||
next_input_ids,
|
||||
next_token_logprobs,
|
||||
|
@ -851,16 +853,15 @@ class FlashCausalLM(Model):
|
|||
) = batch.next_token_chooser(
|
||||
batch.all_input_ids_tensor[:, : batch.max_seqlen],
|
||||
next_token_logits,
|
||||
get_speculate(),
|
||||
speculate,
|
||||
batch.speculative_ids,
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
||||
)
|
||||
|
||||
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
|
||||
if prefill:
|
||||
if len(batch) > 1 and prefill_logprobs:
|
||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||
|
@ -1062,20 +1063,24 @@ class FlashCausalLM(Model):
|
|||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens.append(top_tokens)
|
||||
top_tokens = all_top_tokens
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
|
|
|
@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
|
|||
batch.past_key_values,
|
||||
)
|
||||
|
||||
# Speculation is not active for seq2seq
|
||||
accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
accepted_ids,
|
||||
)
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
@ -746,20 +749,24 @@ class Seq2SeqLM(Model):
|
|||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens.append(top_tokens)
|
||||
top_tokens = all_top_tokens
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
|
|
|
@ -95,5 +95,5 @@ class Generation:
|
|||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
|
||||
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
|
||||
)
|
||||
|
|
|
@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
|
|||
scores[:, j] = _scores
|
||||
next_ids[:, j] = _next_ids
|
||||
next_ids = next_ids.view(B * S)
|
||||
scores = scores.view(B * S, -1)
|
||||
allscores = scores.view(B * S, -1)
|
||||
alllogprobs = torch.log_softmax(allscores, -1)
|
||||
|
||||
if speculated_ids is not None:
|
||||
accepted_ids = []
|
||||
|
@ -305,16 +306,17 @@ class HeterogeneousNextTokenChooser:
|
|||
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
|
||||
)
|
||||
next_ids = next_ids[indices]
|
||||
scores = scores[indices]
|
||||
logprobs = alllogprobs[indices]
|
||||
indices = torch.arange(B, device=input_ids.device) * S
|
||||
if speculative_scores is not None:
|
||||
speculative_scores = speculative_scores[indices + accepted_ids - 1]
|
||||
else:
|
||||
accepted_ids = torch.ones_like(next_ids)
|
||||
logprobs = alllogprobs
|
||||
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
|
||||
|
||||
if speculate > 0:
|
||||
if speculative_scores is not None:
|
||||
# Medusa provided some scores
|
||||
|
@ -327,7 +329,7 @@ class HeterogeneousNextTokenChooser:
|
|||
else:
|
||||
speculative_ids = None
|
||||
|
||||
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
|
||||
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
|
||||
|
||||
def filter(self, indices):
|
||||
if self.watermark_processor is not None:
|
||||
|
@ -436,8 +438,8 @@ class HeterogeneousSampling:
|
|||
|
||||
|
||||
def batch_top_tokens(
|
||||
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
|
||||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
"""Find the top n most likely tokens for a batch of generations.
|
||||
|
||||
When multiple tokens have equal probabilities and they don't all fit, the
|
||||
|
@ -446,14 +448,19 @@ def batch_top_tokens(
|
|||
max_top_n = max(top_n_tokens)
|
||||
# Early exit when top_n_tokens is not used
|
||||
if max_top_n == 0:
|
||||
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
|
||||
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
|
||||
|
||||
|
||||
batch_size = accepted_ids.shape[0]
|
||||
speculate_size = logprobs.shape[0] // batch_size
|
||||
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
|
||||
# Ensure top_n doesn't exceed vocab size
|
||||
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
|
||||
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
|
||||
|
||||
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
|
||||
# Sorted topk is faster than torch.sort() since we only need a small subset
|
||||
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
|
||||
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
|
||||
|
||||
nth_highest = torch.gather(
|
||||
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
|
||||
)
|
||||
|
@ -471,13 +478,33 @@ def batch_top_tokens(
|
|||
top_indices = top_k.indices.tolist()
|
||||
top_values = top_k.values.tolist()
|
||||
|
||||
return (
|
||||
[
|
||||
idxs[:n] if req_n > 0 else []
|
||||
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
|
||||
],
|
||||
[
|
||||
vals[:n] if req_n > 0 else []
|
||||
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
|
||||
],
|
||||
)
|
||||
batch_top_token_ids = []
|
||||
batch_top_token_logprobs = []
|
||||
accepted_ids_list = accepted_ids.tolist()
|
||||
for i, n_accepted_ids in enumerate(accepted_ids_list):
|
||||
start = speculate_size * i
|
||||
stop = speculate_size * (i + 1)
|
||||
_top_indices = top_indices[start: stop]
|
||||
_top_values = top_values[start: stop]
|
||||
_top_n_ishes = top_n_ishes[start: stop]
|
||||
_top_n_tokens = top_n_tokens[start: stop]
|
||||
|
||||
_top_indices = _top_indices[:n_accepted_ids]
|
||||
_top_values = _top_values[:n_accepted_ids]
|
||||
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
|
||||
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
|
||||
|
||||
row_top_token_ids = []
|
||||
row_top_token_logprobs = []
|
||||
|
||||
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
|
||||
indices = idxs[:n] if req_n > 0 else []
|
||||
values = vals[:n] if req_n > 0 else []
|
||||
|
||||
row_top_token_ids.append(indices)
|
||||
row_top_token_logprobs.append(values)
|
||||
|
||||
batch_top_token_ids.append(row_top_token_ids)
|
||||
batch_top_token_logprobs.append(row_top_token_logprobs)
|
||||
|
||||
return batch_top_token_ids, batch_top_token_logprobs
|
||||
|
|
Loading…
Reference in New Issue