hf_text-generation-inference/server/text_generation/utils/tokens.py

154 lines
4.7 KiB
Python
Raw Normal View History

2023-02-14 05:02:16 -07:00
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor
2023-02-14 05:02:16 -07:00
class Sampling:
    """Sample a token id from a categorical distribution over logits.

    A dedicated ``torch.Generator`` seeded with ``seed`` is used, so draws
    are reproducible for a given seed independently of the global RNG.
    """

    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        """Draw one token id per row of ``logits``.

        Args:
            logits: tensor of unnormalized scores; ``torch.multinomial``
                only accepts 1-D ``(vocab,)`` or 2-D ``(rows, vocab)`` input.

        Returns:
            Sampled indices: shape ``(1,)`` for 1-D input, ``(rows, 1)`` for
            2-D input.
        """
        # Normalize explicitly over the vocabulary (last) axis; softmax
        # without ``dim`` is deprecated and warns on every call.
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens
class Greedy:
    """Deterministically pick the highest-scoring token id."""

    def __call__(self, logits):
        """Return the argmax of ``logits`` over the last (vocabulary) axis.

        For 1-D input this is a 0-d tensor holding the best index. For 2-D
        input one index is returned per row, instead of a single index into
        the flattened tensor.
        """
        return logits.argmax(dim=-1)
class NextTokenChooser:
    """Turn model logits into the next token id.

    Builds a chain of logits processors/warpers from the generation
    parameters, then chooses the next token either greedily (``Greedy``) or
    by seeded sampling (``Sampling``).

    Sampling is used when ``do_sample`` is true, and is also enabled
    implicitly whenever a non-default ``temperature``, ``top_k`` or ``top_p``
    is supplied, since those warpers would not change an argmax choice.
    """

    def __init__(
        self,
        vocab_size,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        """
        Args:
            vocab_size: vocabulary size; only passed to the watermark processor.
            watermark: if true, add a ``WatermarkLogitsProcessor`` first.
            temperature: ``None`` or ``1.0`` disables temperature scaling.
            repetition_penalty: ``None`` or ``1.0`` disables the penalty.
            top_k: ``None`` or ``0`` disables top-k filtering.
            top_p: ``None`` or a value ``>= 1.0`` disables nucleus filtering.
            do_sample: sample even when no warper is active.
            seed: seed for the sampling generator.
            device: device for the sampling generator and watermark processor.
        """
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        if watermark:
            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        """Apply the processors and choose the next token.

        Args:
            input_ids: token ids forwarded to the logits processors.
            scores: logits; the token is chosen from the last row,
                ``scores[-1]``. NOTE(review): presumably ``(seq, vocab)`` —
                confirm against callers.

        Returns:
            ``(next_id, logprobs)``: ``next_id`` reshaped to ``(1, 1)`` and
            the log-softmax of the processed scores over the last axis.
        """
        # Warp logits
        scores = self.warpers(input_ids, scores)
        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)
        # Choose tokens
        next_id = self.choice(scores[-1])
        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        vocab_size: int,
        device: torch.device,
    ) -> "NextTokenChooser":
        """Build a chooser from protobuf ``NextTokenChooserParameters``."""
        # Use ``cls`` so subclasses get an instance of their own type.
        return cls(
            vocab_size=vocab_size,
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )
class StopSequenceCriteria:
    """Detect when generated text ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        # The stop sequence is literal user text, not a pattern: escape it so
        # characters such as "?", "." or "(" match themselves instead of
        # altering the regex or failing to compile. With ``search`` an
        # end anchor is sufficient; no leading ".*" is needed.
        self.regex = re.compile(f"{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        """Return True if ``output`` ends with the stop sequence.

        As with any ``$`` anchor, a match just before a single trailing
        newline also counts.
        """
        return self.regex.search(output) is not None
class StoppingCriteria:
    """Decide, one generated token at a time, whether to stop generating."""

    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        # Number of tokens passed to __call__ so far.
        self.current_tokens = 0
        # Decoded text accumulated so far, matched against stop sequences.
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[int]]:
        """Record one generated token and report whether to stop.

        Checks, in order: the ``max_new_tokens`` budget, the EOS token, then
        each stop sequence against the accumulated text.

        Args:
            last_token: id of the token just generated.
            last_output: decoded text of that token.

        Returns:
            ``(True, reason)`` where ``reason`` is a ``FinishReason`` value
            (Python protobuf enum values are plain ints) when generation
            should stop, otherwise ``(False, None)``.
        """
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        # Only reached for non-final tokens, so the EOS text is never added.
        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build criteria from protobuf ``StoppingCriteriaParameters``.

        The EOS id is taken from ``tokenizer.eos_token_id``.
        """
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # Use ``cls`` so subclasses get an instance of their own type.
        return cls(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )