hf_text-generation-inference/server/text_generation_server/utils/logits_process.py

411 lines
15 KiB
Python
Raw Permalink Normal View History

import math
import torch
from functools import lru_cache
from typing import Optional, List, Dict, Union
from transformers import (
LogitsWarper,
LogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
# Memory pool handle passed to the CUDA graph captures in this module;
# it stays None when CUDA is not available.
if torch.cuda.is_available():
    mempool = torch.cuda.graph_pool_handle()
else:
    mempool = None
class StaticWarper:
    """Run a fixed chain of logits warpers and compute the matching log-probabilities.

    When CUDA is available, the chain is captured into a CUDA graph on the
    first call and replayed on every later call; otherwise the warpers are
    executed eagerly.

    Args:
        temperature: Sampling temperature; `None` or `1.0` adds no warper.
        top_k: Top-k value; `None` or `0` adds no warper.
        top_p: Top-p value; `None` or a value `>= 1.0` adds no warper.
        typical_p: Typical-decoding mass; `None` or a value `>= 1.0` adds no warper.
    """

    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
    ):
        # Only parameters that actually alter the distribution get a warper,
        # in the order temperature -> top-k -> top-p -> typical.
        chain = []
        if temperature is not None and temperature != 1.0:
            chain.append(TemperatureLogitsWarper(float(temperature)))
        if top_k is not None and top_k != 0:
            chain.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            chain.append(TopPLogitsWarper(top_p=top_p))
        if typical_p is not None and typical_p < 1.0:
            chain.append(TypicalLogitsWarper(mass=typical_p))
        self.warpers = chain

        # CUDA graph state, filled in lazily by the first call on a CUDA host.
        self.cuda_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def __call__(self, scores):
        """Warp `scores` and return `(warped_scores, log_softmax(warped_scores))`.

        On the CUDA path the same output tensor objects
        (`static_warped_scores`, `static_next_logprob`) are returned on every call.
        """
        if not torch.cuda.is_available():
            # CPU branch: apply the warpers eagerly.
            warped = scores
            for warper in self.warpers:
                warped = warper(None, warped)
            return warped, torch.log_softmax(warped, -1)

        if self.cuda_graph is None:
            # First call: the incoming tensor becomes the graph's input
            # buffer and the whole chain is recorded once.
            self.static_scores = scores
            self.cuda_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.cuda_graph, pool=mempool):
                warped = self.static_scores
                for warper in self.warpers:
                    warped = warper(None, warped)
                self.static_warped_scores = warped
                # Compute logprobs
                self.static_next_logprob = torch.log_softmax(
                    self.static_warped_scores, -1
                )

        # Feed the new scores into the captured input buffer and replay.
        self.static_scores.copy_(scores)
        self.cuda_graph.replay()
        return self.static_warped_scores, self.static_next_logprob
@lru_cache(maxsize=10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    """Return a `StaticWarper` for the given sampling parameters.

    Results are memoized (up to 10 distinct parameter combinations), so
    repeated calls with the same arguments return the same instance.
    """
    sampling_params = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "typical_p": typical_p,
    }
    return StaticWarper(**sampling_params)
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version accepts one penalty value per sample and modifies `scores` in place.
    It doesn't validate inputs.
    Args:
        penalty (`List[float]`):
            The repetition penalty for each sample. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
        dtype (`torch.dtype`):
            Data type of the penalty tensor.
        device (`torch.device`):
            Device on which the penalty tensor is created.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        # Column vector so that each row of scores is paired with its own penalty.
        per_sample = torch.tensor(penalty, dtype=dtype, device=device)
        self.penalty_tensor = per_sample.unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Scores of the tokens listed in input_ids.
        seen_scores = torch.gather(scores, 1, input_ids)
        # Negative scores are multiplied by the penalty and the others are
        # divided by it, so that the previous token probability is reduced.
        penalized = torch.where(
            seen_scores < 0,
            seen_scores * self.penalty_tensor,
            seen_scores / self.penalty_tensor,
        )
        scores.scatter_(1, input_ids, penalized)
        return scores

    def filter(self, indices):
        """Keep only the samples at `indices`; return `None` when no kept sample has a penalty other than 1.0."""
        kept = [self.penalty[i] for i in indices]
        self.penalty = kept
        if not any(value != 1.0 for value in kept):
            return None
        self.penalty_tensor = self.penalty_tensor[indices]
        return self
class HeterogeneousTemperatureLogitsWarper:
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
    This version accepts one temperature per sample and divides `scores` in place.
    It doesn't validate inputs.
    Args:
        temperature (`List[float]`):
            The value used to modulate the logits distribution, one per sample.
        dtype (`torch.dtype`):
            Data type of the temperature tensor.
        device (`torch.device`):
            Device on which the temperature tensor is created.
    """

    def __init__(
        self, temperature: List[float], dtype: torch.dtype, device: torch.device
    ):
        self.temperature = temperature
        # Column vector so that each row of scores is divided by its own temperature.
        per_sample = torch.tensor(temperature, dtype=dtype, device=device)
        self.temperature_tensor = per_sample.unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # In-place division; div_ hands back the very tensor it modified.
        return scores.div_(self.temperature_tensor)

    def filter(self, indices):
        """Keep only the samples at `indices`; return `None` when every kept temperature equals 1.0."""
        kept = [self.temperature[i] for i in indices]
        self.temperature = kept
        if not any(value != 1.0 for value in kept):
            return None
        self.temperature_tensor = self.temperature_tensor[indices]
        return self
class HeterogeneousTopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to the most probable tokens whose
    probabilities add up to at least `top_p`.
    This version accepts one `top_p` value per sample and masks `scores` in place.
    It doesn't validate inputs.
    Args:
        top_p (`List[float]`):
            Per-sample threshold. If set to < 1, only the smallest set of most probable tokens with
            probabilities that add up to `top_p` or higher are kept for generation.
        dtype (`torch.dtype`):
            Data type of the threshold tensor.
        device (`torch.device`):
            Device on which the threshold tensor is created.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_p: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_p = top_p
        # Store 1 - top_p as a column vector: the cumulative sum below runs
        # over ascending probabilities, so this is the mass to discard.
        kept_mass = torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1)
        self.top_p_opposite = 1 - kept_mass
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Ascending sort: the least likely tokens come first.
        ascending_logits, ascending_indices = torch.sort(scores, descending=False)
        cumulative = ascending_logits.softmax(dim=-1)
        # Row-by-row cumulative sum; the original author notes this is faster.
        for row in range(cumulative.shape[0]):
            cumulative[row] = cumulative[row].cumsum(dim=-1)

        # Drop tokens whose cumulative mass stays within the discarded tail.
        remove_sorted = cumulative <= self.top_p_opposite
        # Keep at least min_tokens_to_keep (the last, most likely, columns).
        remove_sorted[..., -self.min_tokens_to_keep :] = 0

        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(1, ascending_indices, remove_sorted)
        return scores.masked_fill_(remove, self.filter_value)

    def filter(self, indices):
        """Keep only the samples at `indices`; return `None` when no kept `top_p` is below 1.0."""
        kept = [self.top_p[i] for i in indices]
        self.top_p = kept
        if not any(value < 1.0 for value in kept):
            return None
        self.top_p_opposite = self.top_p_opposite[indices]
        return self
class HeterogeneousTopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    This version accepts one `top_k` value per sample and masks `scores` in place.
    It doesn't validate inputs.
    Args:
        top_k (`List[int]`):
            Per-sample number of highest probability vocabulary tokens to keep for
            top-k-filtering. 0 disables top-k warping for that sample.
        device (`torch.device`):
            Device on which the index and mask tensors are created.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_k: List[int],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_k = top_k
        self.max_top_k = max(top_k)
        # value - 1 as we will use top_k to index and python uses 0 based numbering
        self.top_k_tensor = torch.tensor(
            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
            dtype=torch.int64,
            device=device,
        ).unsqueeze(1)

        # 0 is a special value that disables top_k warping for this member of the batch
        disabled = [x == 0 for x in top_k]
        if any(disabled):
            self.top_k_disabled_mask = torch.tensor(
                disabled, dtype=torch.bool, device=device
            ).view(-1, 1)
        else:
            self.top_k_disabled_mask = None

        self.filter_value = filter_value

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # If max_top_k is superior to the vocab, we need to clamp or the warper will fail
        vocab_size = scores.size(-1)
        if vocab_size < self.max_top_k:
            max_top_k = vocab_size
            # top_k_tensor holds 0-based indices into the topk result, which
            # has only `max_top_k` columns here: the largest valid index is
            # therefore max_top_k - 1 (clamping to max_top_k would make the
            # gather below index out of bounds).
            top_k = torch.clamp_max(self.top_k_tensor, max_top_k - 1)
        else:
            max_top_k = self.max_top_k
            top_k = self.top_k_tensor

        # Get the kth score for each member of the batch
        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)

        # Mask member of kth_scores that do not want to use top_k warping:
        # with the threshold set to filter_value, `scores < kth_scores` below
        # matches nothing for the default filter_value of -inf.
        if self.top_k_disabled_mask is not None:
            kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)

        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < kth_scores
        scores.masked_fill_(indices_to_remove, self.filter_value)
        return scores

    def filter(self, indices):
        """Keep only the samples at `indices`; return `None` when top-k is disabled for all of them."""
        self.top_k = [self.top_k[i] for i in indices]
        disabled = [x == 0 for x in self.top_k]

        if not all(disabled):
            self.top_k_tensor = self.top_k_tensor[indices]
            self.max_top_k = max(self.top_k)

            if self.top_k_disabled_mask is not None:
                # Drop the mask entirely once no remaining sample is disabled.
                self.top_k_disabled_mask = (
                    self.top_k_disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
    Generation](https://arxiv.org/abs/2202.00666) for more information.
    This version accepts one mass value per sample and masks `scores` in place.
    It doesn't validate inputs.
    Args:
        mass (`List[float]`):
            Per-sample value of typical_p between 0 and 1 inclusive. 1.0 disables typical
            warping for that sample.
        dtype (`torch.dtype`):
            Data type of the mass tensor.
        device (`torch.device`):
            Device on which the mass and mask tensors are created.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        mass: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.mass = mass
        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)

        # 1 is a special value that disables typical_p warping for this member of the batch
        disabled = [x == 1.0 for x in mass]
        if any(disabled):
            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
        else:
            self.disabled_mask = None

        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort: tokens whose information content is closest to the
        # entropy come first
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        probs = sorted_logits.softmax(dim=-1)
        # This is way faster for some reason
        for i in range(probs.shape[0]):
            probs[i] = probs[i].cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (probs < self.mass_tensor).sum(dim=1)
        # When every cumulative probability is below the mass (e.g. because of
        # rounding in low precision) the count equals the vocabulary size,
        # which is not a valid index for the gather below: clamp it.
        last_ind.clamp_(min=0, max=sorted_scores.shape[-1] - 1)

        if self.disabled_mask is not None:
            # Disabled samples use the last column as threshold, so nothing is removed.
            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep: never remove the first
            # min_tokens_to_keep entries of the sorted order
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )

        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores

    def filter(self, indices):
        """Keep only the samples at `indices`; return `None` when typical warping is disabled for all of them."""
        self.mass = [self.mass[i] for i in indices]
        disabled = [x == 1.0 for x in self.mass]

        if not all(disabled):
            self.mass_tensor = self.mass_tensor[indices]

            if self.disabled_mask is not None:
                # Drop the mask entirely once no remaining sample is disabled.
                self.disabled_mask = (
                    self.disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
class HeterogeneousProcessorWrapper(LogitsProcessor):
    r"""
    A wrapper for logit warpers or processors without heterogeneous parameter support.
    Each wrapped processor is applied only to the batch row it is registered for.
    Args:
        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
            A mapping of sample indices to logit warpers or processors, to be run sequentially.
    """

    def __init__(
        self,
        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
    ):
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for row, processor in self.processors.items():
            # One-row slices keep the batch dimension the processor expects.
            single = slice(row, row + 1)
            scores[single] = processor(input_ids[single], scores[single])
        return scores

    def filter(self, indices):
        """Re-key the processors to the new batch positions given by `indices`.

        Returns `None` (leaving the current mapping untouched) when none of
        the kept samples has a processor.
        """
        remapped = {
            new_pos: self.processors[old_pos]
            for new_pos, old_pos in enumerate(indices)
            if old_pos in self.processors
        }
        if not remapped:
            return None
        self.processors = remapped
        return self