import re import torch from transformers import ( RepetitionPenaltyLogitsProcessor, PreTrainedTokenizerBase, ) from typing import List, Tuple, Optional from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.logits_process import ( static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousProcessorWrapper, ) class NextTokenChooser: def __init__( self, watermark=False, temperature=1.0, repetition_penalty=1.0, top_k=None, top_p=None, typical_p=None, do_sample=False, seed=0, device="cpu", ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None ) self.repetition_processor = ( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty else None ) has_warpers = ( (temperature is not None and temperature != 1.0) or (top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (typical_p is not None and typical_p < 1.0) ) if has_warpers: self.static_warper = static_warper( temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p ) else: self.static_warper = None sampling = do_sample or has_warpers self.choice = Sampling(seed, device) if sampling else Greedy() def __call__(self, input_ids, scores): if self.watermark_processor is not None: scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) else: scores, next_logprob = self.static_warper(scores) next_id = self.choice(scores[-1]).view(1, 1) return next_id, next_logprob @classmethod def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, top_k=pb.top_k, top_p=pb.top_p, typical_p=pb.typical_p, do_sample=pb.do_sample, seed=pb.seed, device=device, ) class StopSequenceCriteria: def __init__(self, stop_sequence: str): stop_sequence = re.escape(stop_sequence) self.regex = re.compile(f".*{stop_sequence}$") def __call__(self, output: str) -> bool: if self.regex.findall(output): return True return False class StoppingCriteria: def __init__( self, eos_token_id: int, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens: int = 20, ignore_eos_token: bool = False, ): self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias self.max_new_tokens = max_new_tokens self.current_tokens = 0 self.current_output = "" self.ignore_eos_token = ignore_eos_token def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: return True, FinishReason.FINISH_REASON_LENGTH if not self.ignore_eos_token and last_token == self.eos_token_id: return True, FinishReason.FINISH_REASON_EOS_TOKEN self.current_output += last_output for stop_sequence_criteria in self.stop_sequence_criterias: if stop_sequence_criteria(self.current_output): return True, FinishReason.FINISH_REASON_STOP_SEQUENCE return False, None @classmethod def from_pb( cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences ] return StoppingCriteria( tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens, pb.ignore_eos_token, ) class HeterogeneousNextTokenChooser: def __init__( self, dtype: torch.dtype, device: torch.device, watermark: List[bool], temperature: List[float], repetition_penalty: List[float], top_k: List[int], top_p: List[float], typical_p: List[float], do_sample: List[bool], seeds: List[int], ): warpers = [] self.watermark_processor = ( HeterogeneousProcessorWrapper( { i: WatermarkLogitsProcessor(device=device) for i, do_watermark in enumerate(watermark) if do_watermark } ) if any(watermark) else None ) self.repetition_processor = ( HeterogeneousRepetitionPenaltyLogitsProcessor( repetition_penalty, dtype, device ) if any([x != 1.0 for x in repetition_penalty]) else None ) if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) ] warpers.append( HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) ) if any([x != 0 for x in top_k]): do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) if any([x < 1.0 for x in top_p]): do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) if any([x < 1.0 for x in typical_p]): do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) self.warpers = warpers if any(do_sample): self.choice = HeterogeneousSampling(do_sample, seeds, device) else: self.choice = Greedy() self.seeds = seeds self.do_sample = do_sample def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): if self.watermark_processor is not None: scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) for warper in self.warpers: scores = warper(input_ids, scores) next_ids = self.choice(scores) next_logprobs = torch.gather( torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1) ).view(-1) return next_ids, next_logprobs def filter(self, indices): if self.watermark_processor is not None: self.watermark_processor = self.watermark_processor.filter(indices) if self.repetition_processor is not None: self.repetition_processor = self.repetition_processor.filter(indices) filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) if filtered_warper is not None: filtered_warpers.append(filtered_warper) self.warpers = filtered_warpers self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] if any(self.do_sample): self.choice.filter(indices) else: self.choice = Greedy() return self @classmethod def from_pb( cls, pb: List[generate_pb2.NextTokenChooserParameters], dtype: torch.dtype, device: torch.device, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], temperature=[pb_.temperature for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb], top_k=[pb_.top_k for pb_ in pb], top_p=[pb_.top_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb], do_sample=[pb_.do_sample for pb_ in pb], seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, ) class Sampling: def __init__(self, seed: int, device: str = "cpu"): self.generator = torch.Generator(device) self.generator.manual_seed(seed) self.seed = seed def __call__(self, logits): probs = torch.nn.functional.softmax(logits, -1) # Avoid GPU<->CPU sync done by torch multinomial # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 q = torch.empty_like(probs).exponential_(1, generator=self.generator) return probs.div_(q).argmax() class Greedy: def __call__(self, logits): return logits.argmax(dim=-1) class HeterogeneousSampling: r""" Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. """ def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): self.seeds = seeds self.greedy_indices = [] self.sampling_mapping = {} for i, (sample, seed) in enumerate(zip(do_sample, seeds)): if sample: self.sampling_mapping[i] = Sampling(seed, device) else: self.greedy_indices.append(i) self.greedy = Greedy() def __call__(self, logits): out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) if self.greedy_indices: # Computing for all indices is faster than slicing torch.argmax(logits, -1, out=out) for i, sampling in self.sampling_mapping.items(): out[i] = sampling(logits[i]) return out def filter(self, indices): new_greedy_indices = [] new_sampling_mapping = {} for i, idx in enumerate(indices): if idx in self.sampling_mapping: new_sampling_mapping[i] = self.sampling_mapping[idx] else: new_greedy_indices.append(i) self.greedy_indices = new_greedy_indices self.sampling_mapping = new_sampling_mapping return self