import math
import time
from functools import lru_cache
from typing import DefaultDict, Dict, List, Optional, Union

import torch
from loguru import logger
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from transformers import (
    LogitsWarper,
    LogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)

from text_generation_server.pb.generate_pb2 import GrammarType

mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None


class StaticWarper:
    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
    ):
        self.warpers = []

        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            self.warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            self.warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            self.warpers.append(TopPLogitsWarper(top_p=top_p))
        if typical_p is not None and typical_p < 1.0:
            self.warpers.append(TypicalLogitsWarper(mass=typical_p))

        self.cuda_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def __call__(self, scores):
        if torch.cuda.is_available():
            if self.cuda_graph is None:
                self.static_scores = scores
                self.cuda_graph = torch.cuda.CUDAGraph()

                with torch.cuda.graph(self.cuda_graph, pool=mempool):
                    local_scores = self.static_scores
                    for warper in self.warpers:
                        local_scores = warper(None, local_scores)

                    self.static_warped_scores = local_scores
                    # Compute logprobs
                    self.static_next_logprob = torch.log_softmax(
                        self.static_warped_scores, -1
                    )

            self.static_scores.copy_(scores)
            self.cuda_graph.replay()

            return self.static_warped_scores, self.static_next_logprob

        # CPU branch
        for warper in self.warpers:
            scores = warper(None, scores)

        return scores, torch.log_softmax(scores, -1)


@lru_cache(10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    return StaticWarper(
        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
    )


class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        repetition_penalty (`List[float]`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 1.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None
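# Example (illustrative sketch, not part of the original module): building a cached
# static warper and a per-sample repetition penalty processor. The shapes and
# parameter values below are made up for demonstration.
#
#     warper = static_warper(temperature=0.7, top_k=50, top_p=0.95, typical_p=None)
#     scores = torch.randn(1, 32000)
#     warped_scores, next_logprob = warper(scores)
#
#     repetition = HeterogeneousRepetitionPenaltyLogitsProcessor(
#         penalty=[1.2, 1.0], dtype=torch.float32, device=torch.device("cpu")
#     )
#     input_ids = torch.randint(0, 32000, (2, 8))
#     scores = repetition(input_ids, torch.randn(2, 32000))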
""" def __init__(self, penalty: float): self.penalty = penalty def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then penalty has to be multiplied to reduce the previous token probability score = -torch.where(score < 0, score * self.penalty, score / self.penalty) # set score to 0 where input_ids is a padding token score *= input_ids.ne(0) return scores.scatter_add_(1, input_ids, score) class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): r""" Frequency penalty as defined by OpenAI in https://platform.openai.com/docs/guides/text-generation/parameter-details Args: frequency_penalty (`List[float]`): The parameter for frequency penalty. 0.0 means no penalty. """ def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): self.penalty = penalty self.penalty_tensor = torch.tensor( penalty, dtype=dtype, device=device ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: batch_size, input_size = input_ids.size() vocab_size = scores.size(1) # Calculate the frequency for each token so far token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) token_freq.scatter_add_( 1, input_ids, torch.ones_like(input_ids, dtype=torch.float) ) token_freq /= input_size # Apply the frequency penalty to logits scores -= token_freq * self.penalty_tensor return scores def filter(self, indices): self.penalty = [self.penalty[i] for i in indices] if any([x != 0.0 for x in self.penalty]): self.penalty_tensor = self.penalty_tensor[indices] return self return None class HeterogeneousTemperatureLogitsWarper: r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. Args: temperature (`float`): The value used to module the logits distribution. """ def __init__( self, temperature: List[float], dtype: torch.dtype, device: torch.device ): self.temperature = temperature self.temperature_tensor = torch.tensor( temperature, dtype=dtype, device=device ).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature_tensor) return scores def filter(self, indices): self.temperature = [self.temperature[i] for i in indices] if any([x != 1.0 for x in self.temperature]): self.temperature_tensor = self.temperature_tensor[indices] return self return None class HeterogeneousTopPLogitsWarper(LogitsWarper): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. Args: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimum number of tokens that cannot be filtered. 
""" def __init__( self, top_p: List[float], dtype: torch.dtype, device: torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1, ): self.top_p = top_p self.top_p_opposite = 1 - torch.tensor( top_p, dtype=dtype, device=device ).unsqueeze(1) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: sorted_logits, sorted_indices = torch.sort(scores, descending=False) probs = sorted_logits.softmax(dim=-1) # This is way faster for some reason for i in range(probs.shape[0]): probs[i] = probs[i].cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite # Keep at least min_tokens_to_keep sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) return warped_scores def filter(self, indices): self.top_p = [self.top_p[i] for i in indices] if any([x < 1.0 for x in self.top_p]): self.top_p_opposite = self.top_p_opposite[indices] return self return None class HeterogeneousTopKLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. Args: top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimum number of tokens that cannot be filtered. 
""" def __init__( self, top_k: List[int], device: torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1, ): self.top_k = top_k self.max_top_k = max(top_k) # value - 1 as we will use top_k to index and python uses 0 based numbering self.top_k_tensor = torch.tensor( [max(x - 1, min_tokens_to_keep - 1) for x in top_k], dtype=torch.int64, device=device, ).unsqueeze(1) # 0 is a special value that disables top_k warping for this member of the batch disabled = [x == 0 for x in top_k] if any(disabled): self.top_k_disabled_mask = torch.tensor( disabled, dtype=torch.bool, device=device ).view(-1, 1) else: self.top_k_disabled_mask = None self.filter_value = filter_value def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # If max_top_k is superior to the vocab, we need to clamp or the warper will fail if scores.size(-1) < self.max_top_k: max_top_k = scores.size(-1) top_k = torch.clamp_max(self.top_k_tensor, max_top_k) else: max_top_k = self.max_top_k top_k = self.top_k_tensor # Get the kth score for each member of the batch kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) # Mask member of kth_scores that do not want to use top_k warping if self.top_k_disabled_mask is not None: kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = scores < kth_scores scores.masked_fill_(indices_to_remove, self.filter_value) return scores def filter(self, indices): self.top_k = [self.top_k[i] for i in indices] disabled = [x == 0 for x in self.top_k] if not all(disabled): self.top_k_tensor = self.top_k_tensor[indices] self.max_top_k = max(self.top_k) if self.top_k_disabled_mask is not None: self.top_k_disabled_mask = ( self.top_k_disabled_mask[indices] if any(disabled) else None ) return self return None class HeterogeneousTypicalLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. Args: mass (`float`): Value of typical_p between 0 and 1 inclusive, defaults to 0.9. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimum number of tokens that cannot be filtered. 
""" def __init__( self, mass: List[float], dtype: torch.dtype, device: torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1, ): self.mass = mass self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) # 1 is a special value that disables typical_p warping for this member of the batch disabled = [x == 1.0 for x in mass] if any(disabled): self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device) else: self.disabled_mask = None self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # calculate entropy normalized = torch.nn.functional.log_softmax(scores, dim=-1) p = torch.exp(normalized) ent = -(normalized * p).nansum(-1, keepdim=True) # shift and sort shifted_scores = torch.abs((-normalized) - ent) sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) sorted_logits = scores.gather(-1, sorted_indices) probs = sorted_logits.softmax(dim=-1) # This is way faster for some reason for i in range(probs.shape[0]): probs[i] = probs[i].cumsum(dim=-1) # Remove tokens with cumulative mass above the threshold last_ind = (probs < self.mass_tensor).sum(dim=1) last_ind[last_ind < 0] = 0 if self.disabled_mask is not None: last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) sorted_indices_to_remove = sorted_scores > sorted_scores.gather( 1, last_ind.view(-1, 1) ) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) return warped_scores def filter(self, indices): self.mass = [self.mass[i] for i in indices] disabled = [x == 1.0 for x in self.mass] if not all(disabled): self.mass_tensor = self.mass_tensor[indices] if self.disabled_mask is not None: self.disabled_mask = ( self.disabled_mask[indices] if any(disabled) else None ) return self return None class HeterogeneousProcessorWrapper(LogitsProcessor): r""" A wrapper for logit warpers or processors without heterogeneous parameter support. Args: processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): A mapping of sample indices to logit warpers or processors, to be run sequentially. 
""" def __init__( self, processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], ): self.processors = processors def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: for i, processor in self.processors.items(): scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) return scores def filter(self, indices): new_processors = {} for i, idx in enumerate(indices): if idx in self.processors: new_processors[i] = self.processors[idx] if new_processors: self.processors = new_processors return self return None class GrammarLogitProcessor(LogitsProcessor): fsm_state: DefaultDict[int, int] fsm: RegexFSM def __init__(self, tokenizer, device, grammar, grammar_type): self.device = device self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.fsm = GrammarLogitProcessor._cached_compile_fsm( grammar_type, grammar, self.tokenizer ) def __call__( self, logits: torch.Tensor, fsm_grammar_state: int, ): if fsm_grammar_state == -1 or self.fsm is None: return logits allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) mask = torch.full_like(logits, -math.inf) mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores def advance(self, next_token_id, fsm_grammar_state): return GrammarLogitProcessor._advance( next_token_id, fsm_grammar_state, self.fsm ) @staticmethod def _advance(next_token_id, fsm_grammar_state, fsm): if fsm_grammar_state == -1: return fsm_grammar_state return fsm.next_state(fsm_grammar_state, next_token_id) # TODO: move grammar compilation into the router @staticmethod @lru_cache(maxsize=32, typed=True) def _cached_compile_fsm(grammar_type, schema, tokenizer): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: try: schema = build_regex_from_schema(schema) # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer except Exception as e: logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") # allows everything schema = "(.*?)" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer) logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") return fsm @staticmethod @lru_cache(maxsize=32, typed=True) def _cached_adapt_tokenizer(tokenizer): """Adapt tokenizer to work with the FSM. The API of Outlines tokenizers is slightly different to that of `transformers`. In addition we need to handle the missing spaces to Llama's tokenizer to be able to compile FSMs for this model. 
""" start_time = time.time() tokenizer.vocabulary = tokenizer.get_vocab() tokenizer.special_tokens = set(tokenizer.all_special_tokens) def convert_token_to_string(token: str) -> str: from transformers.file_utils import SPIECE_UNDERLINE string = tokenizer.convert_tokens_to_string([token]) # A hack to handle missing spaces to HF's Llama tokenizers if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": return " " + string return string tokenizer.convert_token_to_string = convert_token_to_string logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s") return tokenizer class HeterogeneousGrammarLogitProcessor(LogitsProcessor): def __init__(self, tokenizer, device, grammars, grammar_types): self.device = device self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.fsms = [] for grammar, grammar_type in zip(grammars, grammar_types): if len(grammar) == 0: self.fsms.append(None) continue fsm = GrammarLogitProcessor._cached_compile_fsm( grammar_type, grammar, self.tokenizer ) self.fsms.append(fsm) def __call__( self, logits: torch.Tensor, fsm_grammar_states: List[int], ): mask = torch.full_like(logits, -math.inf) for i in range(logits.shape[0]): fsm = self.fsms[i] if fsm_grammar_states[i] == -1 or fsm is None: continue allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) mask[i, allowed_tokens] = 0 logits[i] += mask[i] return logits def advance_batch(self, next_token_ids, fsm_grammar_states): return [ GrammarLogitProcessor._advance( next_token_ids[i], fsm_grammar_states[i], self.fsms[i] ) for i in range(len(next_token_ids)) ] def advance_at_index(self, next_token_id, fsm_grammar_state, index): if self.fsms[index] is None: return fsm_grammar_state return GrammarLogitProcessor._advance( next_token_id, fsm_grammar_state, self.fsms[index] ) def filter(self, indices): new_fsms = [] for i in indices: new_fsms.append(self.fsms[i]) self.fsms = new_fsms return self