2023-02-14 05:02:16 -07:00
|
|
|
import re
|
2024-02-08 10:41:25 -07:00
|
|
|
from typing import List, Optional, Tuple
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-08-28 03:43:47 -06:00
|
|
|
import torch
|
2023-03-07 10:52:22 -07:00
|
|
|
from text_generation_server.pb import generate_pb2
|
|
|
|
from text_generation_server.pb.generate_pb2 import FinishReason
|
2023-05-26 04:30:27 -06:00
|
|
|
from text_generation_server.utils.logits_process import (
|
2024-02-08 10:41:25 -07:00
|
|
|
FrequencyPenaltyLogitsProcessor,
|
2023-08-28 03:43:47 -06:00
|
|
|
HeterogeneousProcessorWrapper,
|
2023-05-26 04:30:27 -06:00
|
|
|
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
2024-02-08 10:41:25 -07:00
|
|
|
HeterogeneousFrequencyPenaltyLogitsProcessor,
|
2023-05-26 04:30:27 -06:00
|
|
|
HeterogeneousTemperatureLogitsWarper,
|
|
|
|
HeterogeneousTopKLogitsWarper,
|
|
|
|
HeterogeneousTopPLogitsWarper,
|
|
|
|
HeterogeneousTypicalLogitsWarper,
|
2023-08-28 03:43:47 -06:00
|
|
|
static_warper,
|
2023-05-26 04:30:27 -06:00
|
|
|
)
|
2023-08-28 03:43:47 -06:00
|
|
|
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
|
|
|
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
2023-05-10 11:08:54 -06:00
|
|
|
|
2023-12-11 06:49:52 -07:00
|
|
|
|
2023-05-10 11:08:54 -06:00
|
|
|
class NextTokenChooser:
    """Pick the next token for a single (non-batched) generation.

    Scores are processed in this order: optional watermarking, repetition
    penalty, frequency penalty, then a static warper (temperature / top-k /
    top-p / typical-p) when any of those parameters is non-neutral.
    Sampling is used when ``do_sample`` is set or any warper is active;
    otherwise the choice is greedy.
    """

    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        frequency_penalty=0.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
        )
        # A falsy (None / 0) or neutral (1.0) penalty disables the processor.
        self.repetition_processor = (
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            if repetition_penalty and repetition_penalty != 1.0
            else None
        )
        # A falsy (None / 0.0) frequency penalty disables the processor.
        self.frequency_processor = (
            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
            if frequency_penalty and frequency_penalty != 0.0
            else None
        )

        # Any non-neutral warping parameter enables the static warper.
        has_warpers = (
            (temperature is not None and temperature != 1.0)
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
        )
        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None

        # Warping only makes sense with sampling, so it implies do_sample.
        sampling = do_sample or has_warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        """Process ``scores`` and choose the next token.

        Returns:
            ``(next_id, next_logprob)``. ``next_id`` is chosen from the last
            row of the processed scores and reshaped to ``(1, 1)``.
            ``next_logprob`` holds the log-probabilities of the processed
            scores: ``log_softmax`` over the last dim, or whatever the static
            warper returns.
        """
        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

        next_id = self.choice(scores[-1]).view(1, 1)

        return next_id, next_logprob

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        """Build a chooser from protobuf parameters.

        Uses ``cls`` so subclasses get an instance of their own type.
        """
        return cls(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            frequency_penalty=pb.frequency_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class StopSequenceCriteria:
    """Check whether generated text ends with a literal stop sequence."""

    def __init__(self, stop_sequence: str):
        # Escape the stop sequence so it is matched literally, then anchor
        # the match to the end of the text.
        literal = re.escape(stop_sequence)
        self.regex = re.compile(literal + "$")

    def __call__(self, output: str) -> bool:
        """Return True when ``output`` ends with the stop sequence.

        Python's ``$`` also matches just before a single trailing newline,
        so ``"...<stop>\\n"`` counts as a match too.
        """
        found = self.regex.search(output)
        return found is not None
|
|
|
|
|
|
|
|
|
|
|
|
class StoppingCriteria:
    """Decide, token by token, whether a single generation should stop.

    Conditions are checked in order:
    1. the ``max_new_tokens`` budget;
    2. the EOS token, unless ``ignore_eos_token``;
    3. any stop sequence matched against the tail of the generated text.
    """

    # Minimum number of trailing characters kept for stop-sequence matching.
    _MIN_OUTPUT_WINDOW = 200
    # Extra characters allowed to accumulate before trimming, so the slice
    # is not redone on every token.
    _OUTPUT_SLACK = 100

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token
        # The retained tail must be at least as long as the longest stop
        # pattern. Otherwise trimming can cut a long stop sequence in half and
        # it is never detected. The regex pattern is the escaped sequence plus
        # "$", so its length is a safe upper bound on the sequence length.
        longest_pattern = max(
            (
                len(criteria.regex.pattern)
                for criteria in (stop_sequence_criterias or ())
                if hasattr(criteria, "regex")
            ),
            default=0,
        )
        self._output_window = max(self._MIN_OUTPUT_WINDOW, longest_pattern)

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Record one generated token and report whether to stop.

        Args:
            last_token: id of the token just generated.
            last_output: decoded text of that token.

        Returns:
            ``(stop, reason)``. ``reason`` is a ``FinishReason`` value when
            ``stop`` is True, else ``None``.
        """
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if not self.ignore_eos_token and last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        if self.stop_sequence_criterias:
            self.current_output += last_output
            # Only the tail matters for suffix matching. Trim with some slack
            # to avoid slicing on every call.
            if len(self.current_output) > self._output_window + self._OUTPUT_SLACK:
                self.current_output = self.current_output[-self._output_window :]
            for stop_sequence_criteria in self.stop_sequence_criterias:
                if stop_sequence_criteria(self.current_output):
                    return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build stopping criteria from protobuf parameters and a tokenizer."""
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return cls(
            tokenizer.eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
|
2023-05-26 04:30:27 -06:00
|
|
|
|
2023-12-11 06:49:52 -07:00
|
|
|
|
|
|
|
def create_n_gram_speculation(
    input_ids: torch.Tensor,
    next_ids: torch.Tensor,
    accepted_ids: torch.Tensor,
    speculate: int,
    verbose: bool,
):
    """Propose ``speculate`` draft tokens per request by copying from context.

    This is a deliberately simple, fully on-device heuristic rather than a
    real n-gram model. For each request it:
    1. takes the request's last accepted token;
    2. finds the *first* position where that token occurs in the request's
       row of ``input_ids``;
    3. returns the ``speculate`` tokens that follow it.

    Positions past the end of the row are clamped to the last column. When
    the token never occurs, ``max`` over an all-False row reports index 0,
    so copying starts at column 1.

    Args:
        input_ids: 2-D tensor of token ids, one row per request.
        next_ids: flat tensor with all requests' accepted tokens laid out
            back to back.
        accepted_ids: number of accepted tokens per request.
        speculate: number of draft tokens to produce per request.
        verbose: unused.

    Returns:
        Tensor of shape ``(len(accepted_ids), speculate)`` with the draft ids.
    """
    batch_size = accepted_ids.shape[0]
    last_column = input_ids.shape[1] - 1

    # The running total of accepted counts, minus one, points at each
    # request's last accepted token inside the flat next_ids.
    last_accepted = next_ids[accepted_ids.cumsum(dim=-1) - 1]

    matches = input_ids == last_accepted.unsqueeze(-1)
    first_match = matches.max(dim=1).indices
    start = first_match + 1

    offsets = torch.arange(speculate, device=input_ids.device)
    positions = start.unsqueeze(-1).expand(batch_size, speculate) + offsets
    positions = positions.clamp(max=last_column)

    return input_ids.gather(dim=-1, index=positions)
|
2023-05-26 04:30:27 -06:00
|
|
|
|
2023-12-11 06:49:52 -07:00
|
|
|
|
2023-05-26 04:30:27 -06:00
|
|
|
class HeterogeneousNextTokenChooser:
    """Batched next-token selection where each request has its own parameters.

    Per-request parameter lists are turned into heterogeneous logits
    processors and warpers. Each one is only created when at least one
    request in the batch needs it. ``__call__`` can also verify previously
    speculated tokens and propose new speculative tokens.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
    ):
        warpers = []

        # Watermark processors exist only for requests that asked for them,
        # keyed by their index in the batch.
        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: WatermarkLogitsProcessor(device=device)
                    for i, do_watermark in enumerate(watermark)
                    if do_watermark
                }
            )
            if any(watermark)
            else None
        )

        # Only built if some request uses a non-neutral (!= 1.0) penalty.
        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any([x != 1.0 for x in repetition_penalty])
            else None
        )

        # Only built if some request uses a non-zero frequency penalty.
        self.frequency_processor = (
            HeterogeneousFrequencyPenaltyLogitsProcessor(
                frequency_penalty, dtype, device
            )
            if any([x != 0.0 for x in frequency_penalty])
            else None
        )

        # Each warper below is added when any request needs it. A request
        # with a non-neutral value for that warper is switched to sampling,
        # even if its own do_sample was False.
        if any([x != 1.0 for x in temperature]):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any([x != 0 for x in top_k]):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any([x < 1.0 for x in top_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any([x < 1.0 for x in typical_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        # Mixed greedy/sampling chooser if any request samples, else pure greedy.
        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device

    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
        verbose=False,
    ):
        """Choose next tokens for the batch, verifying and proposing speculation.

        Args:
            input_ids: token ids given to the processors/warpers. For n-gram
                speculation, also searched for draft tokens.
            scores: logits. When ``speculated_ids`` is given, there are
                ``S = speculated_ids.shape[1] + 1`` consecutive rows per
                request, reshaped to ``(B, S, -1)`` below.
            speculate: number of speculative tokens to propose for the next
                step; 0 disables proposal.
            speculated_ids: draft tokens from the previous step, one row per
                request.
            speculative_scores: optional scores (the code comment says they
                come from Medusa), greedily decoded into the next draft
                tokens instead of using n-gram speculation.
            verbose: forwarded to ``create_n_gram_speculation``.

        Returns:
            ``(next_ids, next_logprobs, alllogprobs, accepted_ids,
            speculative_ids)``:
            - ``next_ids``: accepted token ids, flattened across requests.
            - ``next_logprobs``: the log-probabilities of those tokens.
            - ``alllogprobs``: log-probabilities for all ``B * S`` rows.
            - ``accepted_ids``: number of accepted tokens per request.
            - ``speculative_ids``: new draft tokens, or ``None`` when
              ``speculate`` is 0.
        """
        if speculated_ids is not None:
            # S positions per request: one per draft token, plus one extra.
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            scores = scores.view(B, S, -1)
        else:
            B = scores.shape[0]
            S = 1
            scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
        for j in range(S):
            # Process position j for all requests. Every position is
            # processed against the same input_ids; speculated tokens are not
            # appended.
            _scores = scores[:, j]
            if self.watermark_processor is not None:
                _scores = self.watermark_processor(input_ids, _scores)
            if self.repetition_processor is not None:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)

            for warper in self.warpers:
                _scores = warper(input_ids, _scores)

            _next_ids = self.choice(_scores)
            # Write the processed scores back so the log_softmax below
            # reflects them.
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        allscores = scores.view(B * S, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

        if speculated_ids is not None:
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            # Flat positions (into next_ids / alllogprobs) of accepted tokens.
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S : (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                # Draft token k is confirmed when the token chosen at
                # position k equals it.
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # First is always valid
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                    else:
                        # Stop at the first mismatch; later drafts are dropped.
                        break
                accepted_ids.append(accepted)

            accepted_ids = torch.tensor(
                accepted_ids, device=input_ids.device, dtype=input_ids.dtype
            )
            next_ids = next_ids[indices]
            logprobs = alllogprobs[indices]
            # Start row of each request; adding accepted_ids - 1 selects the
            # row of its last accepted position. This assumes
            # speculative_scores has S rows per request (TODO confirm).
            indices = torch.arange(B, device=input_ids.device) * S
            if speculative_scores is not None:
                speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

        # Log-probability of each chosen token.
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

        if speculate > 0:
            if speculative_scores is not None:
                # Medusa provided some scores
                speculative_ids = Greedy()(speculative_scores)
            else:
                # n-gram
                speculative_ids = create_n_gram_speculation(
                    input_ids, next_ids, accepted_ids, speculate, verbose
                )
        else:
            speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    def filter(self, indices):
        """Keep only the requests at ``indices`` (positions in the current batch).

        Processors and warpers are filtered in place. Warpers whose ``filter``
        returns None are dropped. Falls back to ``Greedy`` when no remaining
        request samples. Returns ``self``.
        """
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
            if filtered_warper is not None:
                filtered_warpers.append(filtered_warper)
        self.warpers = filtered_warpers

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

        if any(self.do_sample):
            self.choice.filter(indices)
        else:
            self.choice = Greedy()

        return self

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
    ) -> "HeterogeneousNextTokenChooser":
        """Build a batched chooser from one protobuf parameter set per request."""
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class Sampling:
    """Seeded categorical sampling over the last dimension of the logits."""

    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        """Sample one index along the last dimension of ``logits``.

        ``argmax(p / q)`` with ``q`` i.i.d. ``Exp(1)`` is distributed as
        ``Categorical(p)``. For 1-D logits this returns a 0-d tensor; for
        batched 2-D logits it returns one index per row.
        """
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
        # dim=-1 keeps batched input row-wise (a bare argmax() would flatten);
        # it is identical for 1-D input.
        return probs.div_(q).argmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
class Greedy:
    """Deterministic decoding: always take the highest-scoring index."""

    def __call__(self, logits):
        # Position of the maximum along the last axis, per row.
        best = torch.argmax(logits, dim=-1)
        return best
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousSampling:
    r"""
    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        """Split batch rows into greedy rows and seeded sampling rows.

        Args:
            do_sample: per-row flag; True rows get their own ``Sampling``.
            seeds: per-row seed, used for the sampling rows.
            device: device for the sampling generators.
        """
        self.seeds = seeds

        self.greedy_indices = []
        # Batch row -> its own seeded Sampling instance.
        self.sampling_mapping = {}
        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
            if sample:
                self.sampling_mapping[i] = Sampling(seed, device)
            else:
                self.greedy_indices.append(i)

        self.greedy = Greedy()

    def __call__(self, logits):
        """Return one int64 token id per row of ``logits``."""
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=out)

        # Sampling rows overwrite the greedy result computed above.
        for i, sampling in self.sampling_mapping.items():
            out[i] = sampling(logits[i])
        return out

    def filter(self, indices):
        """Keep only rows at ``indices`` (positions in the current batch).

        Greedy/sampling assignments and seeds are re-indexed so that row
        ``i`` of the new batch corresponds to row ``indices[i]`` of the old
        one. Returns ``self``.
        """
        new_greedy_indices = []
        new_sampling_mapping = {}
        for i, idx in enumerate(indices):
            if idx in self.sampling_mapping:
                new_sampling_mapping[i] = self.sampling_mapping[idx]
            else:
                new_greedy_indices.append(i)

        self.greedy_indices = new_greedy_indices
        self.sampling_mapping = new_sampling_mapping
        # Previously left stale: keep seeds aligned with the surviving rows.
        self.seeds = [self.seeds[idx] for idx in indices]
        return self
|
2023-08-28 03:43:47 -06:00
|
|
|
|
|
|
|
|
|
|
|
def batch_top_tokens(
    top_n_tokens: List[int],
    top_n_tokens_tensor: torch.Tensor,
    logprobs: torch.Tensor,
    accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
    """Find the top n most likely tokens for a batch of generations.

    When multiple tokens have equal probabilities and they don't all fit, the
    remaining tokens are also returned.

    Args:
        top_n_tokens: requested number of top tokens per request.
        top_n_tokens_tensor: the same values as a tensor.
        logprobs: 2-D log-probabilities with ``logprobs.shape[0] //
            len(accepted_ids)`` consecutive rows per request.
        accepted_ids: number of accepted (speculative) positions per request.
            Only that many rows per request are reported.

    Returns:
        ``(ids, logprobs)``, each indexed as ``[request][position][rank]``.
        Every list is a fresh object, so callers may mutate results safely.
    """
    max_top_n = max(top_n_tokens, default=0)
    # Early exit when top_n_tokens is not used
    if max_top_n == 0:
        # Build distinct lists per request; `[[[]]] * n` would alias them.
        return (
            [[[]] for _ in top_n_tokens],
            [[[]] for _ in top_n_tokens],
        )

    vocab_size = logprobs.size(-1)
    # Ensure top_n doesn't exceed vocab size. This applies to the max, the
    # tensor and the list alike; otherwise torch.topk/gather fail.
    max_top_n = min(max_top_n, vocab_size)

    batch_size = accepted_ids.shape[0]
    speculate_size = logprobs.shape[0] // batch_size
    top_n_tokens_tensor = top_n_tokens_tensor.clamp(max=vocab_size).repeat_interleave(
        speculate_size
    )
    top_n_tokens = [
        min(tok, vocab_size) for tok in top_n_tokens for _ in range(speculate_size)
    ]

    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values

    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    # Keep -inf entries out of the ">=" comparison below.
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    # Find the new "fuzzy" top n values. Count per row directly, so rows with
    # no qualifying entry get 0 instead of silently disappearing and
    # shifting the counts of later rows.
    top_n_ishes = (logprobs >= nth_highest).sum(dim=-1)

    k = max(int(top_n_ishes.max()), 1) if top_n_ishes.numel() > 0 else 1
    # Take a new topk for these new max n values
    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

    top_n_ishes = top_n_ishes.tolist()
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    batch_top_token_ids = []
    batch_top_token_logprobs = []
    for i, n_accepted_ids in enumerate(accepted_ids.tolist()):
        # Rows of request i, limited to its accepted positions.
        start = speculate_size * i
        stop = start + min(n_accepted_ids, speculate_size)

        row_top_token_ids = []
        row_top_token_logprobs = []
        for idxs, vals, n, req_n in zip(
            top_indices[start:stop],
            top_values[start:stop],
            top_n_ishes[start:stop],
            top_n_tokens[start:stop],
        ):
            row_top_token_ids.append(idxs[:n] if req_n > 0 else [])
            row_top_token_logprobs.append(vals[:n] if req_n > 0 else [])

        batch_top_token_ids.append(row_top_token_ids)
        batch_top_token_logprobs.append(row_top_token_logprobs)

    return batch_top_token_ids, batch_top_token_logprobs
|