hf_text-generation-inference/server/text_generation_server/utils/tokens.py

647 lines
23 KiB
Python
Raw Normal View History

2023-02-14 05:02:16 -07:00
import re
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
from typing import List, Optional, Tuple, Set, Union
2023-02-14 05:02:16 -07:00
import math
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
import torch
2023-03-07 10:52:22 -07:00
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousGrammarLogitProcessor,
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
static_warper,
)
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
2023-12-11 06:49:52 -07:00
class NextTokenChooser:
def __init__(
self,
watermark: bool = False,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
frequency_penalty: float = 0.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
do_sample: bool = False,
seed: int = 0,
device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
)
self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty and repetition_penalty != 1.0
else None
)
self.frequency_processor = (
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
if frequency_penalty and frequency_penalty != 0.0
else None
)
self.grammar_processor = (
GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
if grammar != ""
else None
)
self.tokenizer = tokenizer
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0)
)
if has_warpers:
self.static_warper = static_warper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
else:
self.static_warper = None
sampling = do_sample or has_warpers
2023-02-14 05:02:16 -07:00
self.choice = Sampling(seed, device) if sampling else Greedy()
self.fsm_grammar_state = fsm_grammar_state
self.grammar = grammar
2023-02-14 05:02:16 -07:00
def __call__(self, input_ids, scores):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(scores, self.fsm_grammar_state)
2023-02-14 05:02:16 -07:00
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
else:
scores, next_logprob = self.static_warper(scores)
2023-02-14 05:02:16 -07:00
next_id = self.choice(scores[-1]).view(1, 1)
2023-02-14 05:02:16 -07:00
return next_id, next_logprob
2023-02-14 05:02:16 -07:00
def advance_grammar(self, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_state = self.grammar_processor.advance(
next_id, self.fsm_grammar_state
)
return self
2023-02-14 05:02:16 -07:00
@classmethod
def from_pb(
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
2023-02-14 05:02:16 -07:00
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
2023-02-14 05:02:16 -07:00
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
frequency_penalty=pb.frequency_penalty,
2023-02-14 05:02:16 -07:00
top_k=pb.top_k,
top_p=pb.top_p,
typical_p=pb.typical_p,
2023-02-14 05:02:16 -07:00
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
2023-02-14 05:02:16 -07:00
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
stop_sequence = re.escape(stop_sequence)
self.regex = re.compile(f"{stop_sequence}$")
2023-02-14 05:02:16 -07:00
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
eos_token_ids: Optional[Union[Set[int], int]],
2023-02-14 05:02:16 -07:00
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
2023-02-14 05:02:16 -07:00
):
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
2023-02-14 05:02:16 -07:00
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
self.ignore_eos_token = ignore_eos_token
2023-02-14 05:02:16 -07:00
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
2023-02-14 05:02:16 -07:00
return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias:
self.current_output += last_output
# There is no need to keep an output that is too long
if len(self.current_output) > 300:
# Slice to -200 to avoid doing it all the time
self.current_output = self.current_output[-200:]
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
2023-02-14 05:02:16 -07:00
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
# TODO Hack because eos_token_id cannot be what we want.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
2023-02-14 05:02:16 -07:00
return StoppingCriteria(
Use the generation config. (#1808) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-04-25 11:41:50 -06:00
eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
2023-02-14 05:02:16 -07:00
)
2023-12-11 06:49:52 -07:00
def create_n_gram_speculation(
input_ids: torch.Tensor,
next_ids: torch.Tensor,
accepted_ids: torch.Tensor,
speculate: int,
verbose: bool,
):
2023-12-11 04:46:30 -07:00
# Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0]
device = input_ids.device
2023-12-11 06:49:52 -07:00
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
2023-12-11 04:46:30 -07:00
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
2023-12-11 06:49:52 -07:00
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
speculate, device=device
)
2023-12-11 04:46:30 -07:00
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids
2023-12-11 06:49:52 -07:00
class HeterogeneousNextTokenChooser:
def __init__(
self,
dtype: torch.dtype,
device: torch.device,
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states=List[int],
):
warpers = []
self.watermark_processor = (
HeterogeneousProcessorWrapper(
{
i: WatermarkLogitsProcessor(device=device)
for i, do_watermark in enumerate(watermark)
if do_watermark
}
)
if any(watermark)
else None
)
self.repetition_processor = (
HeterogeneousRepetitionPenaltyLogitsProcessor(
repetition_penalty, dtype, device
)
if any([x != 1.0 for x in repetition_penalty])
else None
)
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types
)
if any([grammar != "" for grammar in grammars])
else None
)
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
self.warpers = warpers
if any(do_sample):
self.choice = HeterogeneousSampling(do_sample, seeds, device)
else:
self.choice = Greedy()
self.seeds = seeds
self.do_sample = do_sample
self.dtype = dtype
self.device = device
self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
2023-12-11 06:49:52 -07:00
def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False,
):
2023-12-11 04:46:30 -07:00
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
else:
B = scores.shape[0]
S = 1
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
2023-12-11 04:46:30 -07:00
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
if self.grammar_processor is not None:
2024-02-16 09:50:57 -07:00
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
2023-12-11 04:46:30 -07:00
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
2023-12-11 06:49:52 -07:00
next_ids = next_ids.view(B * S)
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)
2023-12-11 04:46:30 -07:00
if speculated_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
2023-12-11 06:49:52 -07:00
_next_ids = next_ids[i * S : (i + 1) * S]
2023-12-11 04:46:30 -07:00
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
accepted += 1
indices.append(index)
else:
break
accepted_ids.append(accepted)
2023-12-11 06:49:52 -07:00
accepted_ids = torch.tensor(
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
2023-12-11 04:46:30 -07:00
next_ids = next_ids[indices]
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
logprobs = alllogprobs[indices]
2023-12-11 04:46:30 -07:00
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
logprobs = alllogprobs
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
2023-12-11 04:46:30 -07:00
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
speculative_ids = Greedy()(speculative_scores)
else:
# n-gram
2023-12-11 06:49:52 -07:00
speculative_ids = create_n_gram_speculation(
input_ids, next_ids, accepted_ids, speculate, verbose
)
2023-12-11 04:46:30 -07:00
else:
speculative_ids = None
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch(
2024-02-16 09:50:57 -07:00
next_ids, self.fsm_grammar_states
)
self.fsm_grammar_states = other_new_states
return self
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None:
2024-02-16 03:58:58 -07:00
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
)
return self
def filter(self, indices):
if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices)
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
if self.grammar_processor is not None:
self.grammar_processor = self.grammar_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
if filtered_warper is not None:
filtered_warpers.append(filtered_warper)
self.warpers = filtered_warpers
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
new_grammar_types = []
for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample):
self.choice.filter(indices)
else:
self.choice = Greedy()
return self
@classmethod
def from_pb(
cls,
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
fsm_grammar_states: Optional[List[int]] = None,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
do_sample=[pb_.do_sample for pb_ in pb],
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
class Greedy:
def __call__(self, logits):
return logits.argmax(dim=-1)
class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
"""
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
self.seeds = seeds
self.greedy_indices = []
self.sampling_mapping = {}
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
if sample:
self.sampling_mapping[i] = Sampling(seed, device)
else:
self.greedy_indices.append(i)
self.greedy = Greedy()
def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices:
# Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out)
for i, sampling in self.sampling_mapping.items():
out[i] = sampling(logits[i])
return out
def filter(self, indices):
new_greedy_indices = []
new_sampling_mapping = {}
for i, idx in enumerate(indices):
if idx in self.sampling_mapping:
new_sampling_mapping[i] = self.sampling_mapping[idx]
else:
new_greedy_indices.append(i)
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
def batch_top_tokens(
top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
2023-09-27 04:22:09 -06:00
Fixing top_k tokens when k ends up < 0 (#966) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2023-08-31 16:22:03 -06:00
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
# Take a new topk for these new max n values
Fixing top_k tokens when k ends up < 0 (#966) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2023-08-31 16:22:03 -06:00
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 03:43:47 -06:00
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
batch_top_token_ids = []
batch_top_token_logprobs = []
accepted_ids_list = accepted_ids.tolist()
for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i
stop = speculate_size * (i + 1)
_top_indices = top_indices[start:stop]
_top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start:stop]
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
Fixing top_n_tokens. (#1497) # What does this PR do? Superseeds #1459 The fix works as follows. We updated next_token_chooser to return all logprbs, then batch_top_n_tokens, now also gets accepted_ids + speculated_length (so it knows how to interpret the flat logprobs). We then update the code to return lists ot `Tokens` that it expects. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
2024-01-26 12:13:47 -07:00
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
row_top_token_ids.append(indices)
row_top_token_logprobs.append(values)
batch_top_token_ids.append(row_top_token_ids)
batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs