2023-05-26 04:30:27 -06:00
|
|
|
import math
|
|
|
|
import torch
|
|
|
|
|
2024-02-15 02:28:10 -07:00
|
|
|
from loguru import logger
|
2024-02-16 09:50:57 -07:00
|
|
|
from typing import Dict, Union
|
2024-02-15 02:28:10 -07:00
|
|
|
from text_generation_server.pb.generate_pb2 import GrammarType
|
|
|
|
|
|
|
|
from outlines.fsm.fsm import RegexFSM
|
|
|
|
from outlines.fsm.json_schema import build_regex_from_object
|
|
|
|
from functools import lru_cache
|
|
|
|
from typing import List, Optional, DefaultDict
|
|
|
|
import time
|
2023-05-26 04:30:27 -06:00
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
LogitsWarper,
|
|
|
|
LogitsProcessor,
|
|
|
|
TemperatureLogitsWarper,
|
|
|
|
TopKLogitsWarper,
|
|
|
|
TopPLogitsWarper,
|
|
|
|
TypicalLogitsWarper,
|
|
|
|
)
|
|
|
|
|
|
|
|
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
|
|
|
|
|
|
|
|
|
|
|
class StaticWarper:
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
temperature=1.0,
|
|
|
|
top_k=None,
|
|
|
|
top_p=None,
|
|
|
|
typical_p=None,
|
|
|
|
):
|
|
|
|
self.warpers = []
|
|
|
|
|
|
|
|
if temperature is not None and temperature != 1.0:
|
|
|
|
temperature = float(temperature)
|
|
|
|
self.warpers.append(TemperatureLogitsWarper(temperature))
|
|
|
|
if top_k is not None and top_k != 0:
|
|
|
|
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
|
|
|
if top_p is not None and top_p < 1.0:
|
|
|
|
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
|
|
|
if typical_p is not None and typical_p < 1.0:
|
|
|
|
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
|
|
|
|
|
|
|
self.cuda_graph = None
|
|
|
|
self.static_scores = None
|
|
|
|
self.static_warped_scores = None
|
|
|
|
self.static_next_logprob = None
|
|
|
|
|
|
|
|
def __call__(self, scores):
|
2023-06-20 03:06:10 -06:00
|
|
|
if torch.cuda.is_available():
|
|
|
|
if self.cuda_graph is None:
|
|
|
|
self.static_scores = scores
|
|
|
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
|
|
|
|
|
|
|
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
|
|
|
local_scores = self.static_scores
|
|
|
|
for warper in self.warpers:
|
|
|
|
local_scores = warper(None, local_scores)
|
|
|
|
|
|
|
|
self.static_warped_scores = local_scores
|
|
|
|
# Compute logprobs
|
|
|
|
self.static_next_logprob = torch.log_softmax(
|
|
|
|
self.static_warped_scores, -1
|
|
|
|
)
|
|
|
|
|
|
|
|
self.static_scores.copy_(scores)
|
|
|
|
self.cuda_graph.replay()
|
|
|
|
|
|
|
|
return self.static_warped_scores, self.static_next_logprob
|
|
|
|
|
|
|
|
# CPU branch
|
|
|
|
for warper in self.warpers:
|
|
|
|
scores = warper(None, scores)
|
|
|
|
return scores, torch.log_softmax(scores, -1)
|
2023-05-26 04:30:27 -06:00
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(10)
|
|
|
|
def static_warper(
|
|
|
|
temperature: Optional[float],
|
|
|
|
top_k: Optional[int],
|
|
|
|
top_p: Optional[float],
|
|
|
|
typical_p: Optional[float],
|
|
|
|
) -> StaticWarper:
|
|
|
|
return StaticWarper(
|
|
|
|
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|
|
|
r"""
|
|
|
|
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
|
|
|
This version allows for a separate value for each sample and runs inplace when possible.
|
|
|
|
It doesn't validate inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
repetition_penalty (`List[float]`):
|
|
|
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
|
|
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
|
|
|
|
self.penalty = penalty
|
|
|
|
self.penalty_tensor = torch.tensor(
|
|
|
|
penalty, dtype=dtype, device=device
|
|
|
|
).unsqueeze(1)
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
score = torch.gather(scores, 1, input_ids)
|
|
|
|
|
|
|
|
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
|
|
|
score = torch.where(
|
|
|
|
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
|
|
|
)
|
|
|
|
|
|
|
|
scores.scatter_(1, input_ids, score)
|
|
|
|
return scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.penalty = [self.penalty[i] for i in indices]
|
|
|
|
if any([x != 1.0 for x in self.penalty]):
|
|
|
|
self.penalty_tensor = self.penalty_tensor[indices]
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
2024-02-08 10:41:25 -07:00
|
|
|
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|
|
|
r"""
|
|
|
|
Frequency penalty as defined by OpenAI
|
|
|
|
|
|
|
|
Args:
|
|
|
|
penalty (`float`):
|
|
|
|
The parameter for frequency penalty. 0.0 means no penalty.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, penalty: float):
|
|
|
|
self.penalty = penalty
|
|
|
|
|
|
|
|
def __call__(
|
|
|
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
|
|
|
) -> torch.FloatTensor:
|
|
|
|
score = torch.gather(scores, 1, input_ids)
|
|
|
|
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
2024-02-15 02:28:10 -07:00
|
|
|
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
|
2024-02-08 10:41:25 -07:00
|
|
|
|
|
|
|
return scores.scatter_add_(1, input_ids, score)
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|
|
|
r"""
|
|
|
|
Frequency penalty as defined by OpenAI
|
|
|
|
|
|
|
|
Args:
|
|
|
|
frequency_penalty (`List[float]`):
|
|
|
|
The parameter for frequency penalty. 0.0 means no penalty.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
|
|
|
|
self.penalty = penalty
|
|
|
|
self.penalty_tensor = torch.tensor(
|
|
|
|
penalty, dtype=dtype, device=device
|
|
|
|
).unsqueeze(1)
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
score = torch.gather(scores, 1, input_ids)
|
|
|
|
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
|
|
|
score = -torch.where(
|
|
|
|
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
|
|
|
)
|
|
|
|
|
|
|
|
return scores.scatter_add_(1, input_ids, score)
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.penalty = [self.penalty[i] for i in indices]
|
|
|
|
if any([x != 0.0 for x in self.penalty]):
|
|
|
|
self.penalty_tensor = self.penalty_tensor[indices]
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
2023-05-26 04:30:27 -06:00
|
|
|
class HeterogeneousTemperatureLogitsWarper:
|
|
|
|
r"""
|
|
|
|
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
|
|
|
This version allows for a separate value for each sample and runs inplace when possible.
|
|
|
|
It doesn't validate inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
temperature (`float`):
|
|
|
|
The value used to module the logits distribution.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self, temperature: List[float], dtype: torch.dtype, device: torch.device
|
|
|
|
):
|
|
|
|
self.temperature = temperature
|
|
|
|
self.temperature_tensor = torch.tensor(
|
|
|
|
temperature, dtype=dtype, device=device
|
|
|
|
).unsqueeze(1)
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
scores.div_(self.temperature_tensor)
|
|
|
|
return scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.temperature = [self.temperature[i] for i in indices]
|
|
|
|
if any([x != 1.0 for x in self.temperature]):
|
|
|
|
self.temperature_tensor = self.temperature_tensor[indices]
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|
|
|
"""
|
|
|
|
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
|
|
|
This version allows for a separate value for each sample and runs inplace when possible.
|
|
|
|
It doesn't validate inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
top_p (`float`):
|
|
|
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
|
|
|
higher are kept for generation.
|
|
|
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
|
|
|
All filtered values will be set to this float value.
|
|
|
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
|
|
|
Minimum number of tokens that cannot be filtered.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
top_p: List[float],
|
|
|
|
dtype: torch.dtype,
|
|
|
|
device: torch.device,
|
|
|
|
filter_value: float = -math.inf,
|
|
|
|
min_tokens_to_keep: int = 1,
|
|
|
|
):
|
|
|
|
self.top_p = top_p
|
|
|
|
self.top_p_opposite = 1 - torch.tensor(
|
|
|
|
top_p, dtype=dtype, device=device
|
|
|
|
).unsqueeze(1)
|
|
|
|
self.filter_value = filter_value
|
|
|
|
self.min_tokens_to_keep = min_tokens_to_keep
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
|
|
|
probs = sorted_logits.softmax(dim=-1)
|
|
|
|
# This is way faster for some reason
|
|
|
|
for i in range(probs.shape[0]):
|
|
|
|
probs[i] = probs[i].cumsum(dim=-1)
|
|
|
|
|
|
|
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
|
|
|
sorted_indices_to_remove = probs <= self.top_p_opposite
|
2023-07-04 12:11:33 -06:00
|
|
|
# Keep at least min_tokens_to_keep
|
|
|
|
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
2023-05-26 04:30:27 -06:00
|
|
|
|
|
|
|
# scatter sorted tensors to original indexing
|
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
|
1, sorted_indices, sorted_indices_to_remove
|
|
|
|
)
|
|
|
|
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
|
|
|
|
|
|
|
return warped_scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.top_p = [self.top_p[i] for i in indices]
|
|
|
|
if any([x < 1.0 for x in self.top_p]):
|
|
|
|
self.top_p_opposite = self.top_p_opposite[indices]
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
|
|
|
r"""
|
|
|
|
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
|
|
|
This version allows for a separate value for each sample and runs inplace when possible.
|
|
|
|
It doesn't validate inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
top_k (`int`):
|
|
|
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
|
|
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
|
|
|
All filtered values will be set to this float value.
|
|
|
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
|
|
|
Minimum number of tokens that cannot be filtered.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
top_k: List[int],
|
|
|
|
device: torch.device,
|
|
|
|
filter_value: float = -math.inf,
|
|
|
|
min_tokens_to_keep: int = 1,
|
|
|
|
):
|
|
|
|
self.top_k = top_k
|
|
|
|
self.max_top_k = max(top_k)
|
|
|
|
# value - 1 as we will use top_k to index and python uses 0 based numbering
|
|
|
|
self.top_k_tensor = torch.tensor(
|
|
|
|
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
|
|
|
|
dtype=torch.int64,
|
|
|
|
device=device,
|
|
|
|
).unsqueeze(1)
|
|
|
|
|
|
|
|
# 0 is a special value that disables top_k warping for this member of the batch
|
|
|
|
disabled = [x == 0 for x in top_k]
|
|
|
|
|
|
|
|
if any(disabled):
|
|
|
|
self.top_k_disabled_mask = torch.tensor(
|
|
|
|
disabled, dtype=torch.bool, device=device
|
|
|
|
).view(-1, 1)
|
|
|
|
else:
|
|
|
|
self.top_k_disabled_mask = None
|
|
|
|
|
|
|
|
self.filter_value = filter_value
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
# If max_top_k is superior to the vocab, we need to clamp or the warper will fail
|
|
|
|
if scores.size(-1) < self.max_top_k:
|
|
|
|
max_top_k = scores.size(-1)
|
|
|
|
top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
|
|
|
|
else:
|
|
|
|
max_top_k = self.max_top_k
|
|
|
|
top_k = self.top_k_tensor
|
|
|
|
|
|
|
|
# Get the kth score for each member of the batch
|
|
|
|
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
|
|
|
|
|
|
|
|
# Mask member of kth_scores that do not want to use top_k warping
|
|
|
|
if self.top_k_disabled_mask is not None:
|
|
|
|
kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
|
|
|
|
|
|
|
|
# Remove all tokens with a probability less than the last token of the top-k
|
|
|
|
indices_to_remove = scores < kth_scores
|
|
|
|
scores.masked_fill_(indices_to_remove, self.filter_value)
|
|
|
|
return scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.top_k = [self.top_k[i] for i in indices]
|
|
|
|
disabled = [x == 0 for x in self.top_k]
|
|
|
|
|
|
|
|
if not all(disabled):
|
|
|
|
self.top_k_tensor = self.top_k_tensor[indices]
|
|
|
|
self.max_top_k = max(self.top_k)
|
|
|
|
|
|
|
|
if self.top_k_disabled_mask is not None:
|
|
|
|
self.top_k_disabled_mask = (
|
|
|
|
self.top_k_disabled_mask[indices] if any(disabled) else None
|
|
|
|
)
|
|
|
|
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
|
|
|
r"""
|
|
|
|
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
|
|
|
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
|
|
|
This version allows for a separate value for each sample and runs inplace when possible.
|
|
|
|
It doesn't validate inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
mass (`float`):
|
|
|
|
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
|
|
|
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
|
|
|
All filtered values will be set to this float value.
|
|
|
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
|
|
|
Minimum number of tokens that cannot be filtered.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
mass: List[float],
|
|
|
|
dtype: torch.dtype,
|
|
|
|
device: torch.device,
|
|
|
|
filter_value: float = -math.inf,
|
|
|
|
min_tokens_to_keep: int = 1,
|
|
|
|
):
|
|
|
|
self.mass = mass
|
|
|
|
self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
|
|
|
|
|
|
|
|
# 1 is a special value that disables typical_p warping for this member of the batch
|
|
|
|
disabled = [x == 1.0 for x in mass]
|
|
|
|
|
|
|
|
if any(disabled):
|
|
|
|
self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
|
|
|
|
else:
|
|
|
|
self.disabled_mask = None
|
|
|
|
|
|
|
|
self.filter_value = filter_value
|
|
|
|
self.min_tokens_to_keep = min_tokens_to_keep
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
# calculate entropy
|
|
|
|
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
|
|
|
p = torch.exp(normalized)
|
|
|
|
ent = -(normalized * p).nansum(-1, keepdim=True)
|
|
|
|
|
|
|
|
# shift and sort
|
|
|
|
shifted_scores = torch.abs((-normalized) - ent)
|
|
|
|
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
|
|
|
sorted_logits = scores.gather(-1, sorted_indices)
|
|
|
|
probs = sorted_logits.softmax(dim=-1)
|
|
|
|
# This is way faster for some reason
|
|
|
|
for i in range(probs.shape[0]):
|
|
|
|
probs[i] = probs[i].cumsum(dim=-1)
|
|
|
|
|
|
|
|
# Remove tokens with cumulative mass above the threshold
|
|
|
|
last_ind = (probs < self.mass_tensor).sum(dim=1)
|
|
|
|
last_ind[last_ind < 0] = 0
|
|
|
|
|
|
|
|
if self.disabled_mask is not None:
|
|
|
|
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
|
|
|
|
|
|
|
|
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
|
|
|
|
1, last_ind.view(-1, 1)
|
|
|
|
)
|
|
|
|
if self.min_tokens_to_keep > 1:
|
|
|
|
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
|
|
|
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
|
1, sorted_indices, sorted_indices_to_remove
|
|
|
|
)
|
|
|
|
|
|
|
|
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
|
|
|
|
|
|
|
return warped_scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
self.mass = [self.mass[i] for i in indices]
|
|
|
|
disabled = [x == 1.0 for x in self.mass]
|
|
|
|
|
|
|
|
if not all(disabled):
|
|
|
|
self.mass_tensor = self.mass_tensor[indices]
|
|
|
|
|
|
|
|
if self.disabled_mask is not None:
|
|
|
|
self.disabled_mask = (
|
|
|
|
self.disabled_mask[indices] if any(disabled) else None
|
|
|
|
)
|
|
|
|
|
|
|
|
return self
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousProcessorWrapper(LogitsProcessor):
|
|
|
|
r"""
|
|
|
|
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
|
|
|
Args:
|
|
|
|
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
|
|
|
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
|
|
|
):
|
|
|
|
self.processors = processors
|
|
|
|
|
|
|
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
|
|
for i, processor in self.processors.items():
|
|
|
|
scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
|
|
|
|
return scores
|
|
|
|
|
|
|
|
def filter(self, indices):
|
|
|
|
new_processors = {}
|
|
|
|
for i, idx in enumerate(indices):
|
|
|
|
if idx in self.processors:
|
|
|
|
new_processors[i] = self.processors[idx]
|
|
|
|
|
|
|
|
if new_processors:
|
|
|
|
self.processors = new_processors
|
|
|
|
return self
|
|
|
|
return None
|
2024-02-15 02:28:10 -07:00
|
|
|
|
|
|
|
|
|
|
|
class GrammarLogitProcessor(LogitsProcessor):
|
|
|
|
fsm_state: DefaultDict[int, int]
|
|
|
|
fsm: RegexFSM
|
|
|
|
|
|
|
|
def __init__(self, tokenizer, device, grammar, grammar_type):
|
|
|
|
self.device = device
|
|
|
|
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
|
|
|
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
|
|
|
grammar_type, grammar, self.tokenizer
|
|
|
|
)
|
|
|
|
|
|
|
|
def __call__(
|
|
|
|
self,
|
|
|
|
logits: torch.Tensor,
|
|
|
|
fsm_grammar_state: int,
|
|
|
|
):
|
|
|
|
if fsm_grammar_state == -1 or self.fsm is None:
|
|
|
|
return logits
|
|
|
|
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
2024-02-16 09:50:57 -07:00
|
|
|
mask = torch.full_like(logits, -math.inf)
|
2024-02-15 02:28:10 -07:00
|
|
|
mask[allowed_tokens] = 0
|
|
|
|
biased_scores = logits + mask
|
|
|
|
return biased_scores
|
|
|
|
|
|
|
|
def advance(self, next_token_id, fsm_grammar_state):
|
|
|
|
return GrammarLogitProcessor._advance(
|
|
|
|
next_token_id, fsm_grammar_state, self.fsm
|
|
|
|
)
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _advance(next_token_id, fsm_grammar_state, fsm):
|
|
|
|
if fsm_grammar_state == -1:
|
|
|
|
return fsm_grammar_state
|
|
|
|
return fsm.next_state(fsm_grammar_state, next_token_id)
|
|
|
|
|
|
|
|
# TODO: move grammar compilation into the router
|
|
|
|
@staticmethod
|
|
|
|
@lru_cache(maxsize=32, typed=True)
|
|
|
|
def _cached_compile_fsm(grammar_type, schema, tokenizer):
|
|
|
|
start_time = time.time()
|
|
|
|
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
|
|
|
schema = build_regex_from_object(schema)
|
|
|
|
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
2024-02-16 03:58:58 -07:00
|
|
|
pass # schema is already a regex just here for clarity
|
2024-02-15 02:28:10 -07:00
|
|
|
fsm = RegexFSM(schema, tokenizer)
|
|
|
|
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
|
|
|
return fsm
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@lru_cache(maxsize=32, typed=True)
|
|
|
|
def _cached_adapt_tokenizer(tokenizer):
|
|
|
|
"""Adapt tokenizer to work with the FSM.
|
|
|
|
|
|
|
|
The API of Outlines tokenizers is slightly different to that of
|
|
|
|
`transformers`. In addition we need to handle the missing spaces to
|
|
|
|
Llama's tokenizer to be able to compile FSMs for this model.
|
|
|
|
|
|
|
|
"""
|
|
|
|
start_time = time.time()
|
|
|
|
tokenizer.vocabulary = tokenizer.get_vocab()
|
|
|
|
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
|
|
|
|
|
|
|
def convert_token_to_string(token: str) -> str:
|
|
|
|
from transformers.file_utils import SPIECE_UNDERLINE
|
|
|
|
|
|
|
|
string = tokenizer.convert_tokens_to_string([token])
|
|
|
|
|
|
|
|
# A hack to handle missing spaces to HF's Llama tokenizers
|
|
|
|
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
|
|
|
return " " + string
|
|
|
|
|
|
|
|
return string
|
|
|
|
|
|
|
|
tokenizer.convert_token_to_string = convert_token_to_string
|
|
|
|
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
|
|
|
|
return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
2024-02-16 09:50:57 -07:00
|
|
|
def __init__(self, tokenizer, device, grammars, grammar_types):
|
2024-02-15 02:28:10 -07:00
|
|
|
self.device = device
|
|
|
|
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
|
|
|
self.fsms = []
|
2024-02-16 09:50:57 -07:00
|
|
|
for grammar, grammar_type in zip(grammars, grammar_types):
|
2024-02-15 02:28:10 -07:00
|
|
|
fsm = GrammarLogitProcessor._cached_compile_fsm(
|
2024-02-16 09:50:57 -07:00
|
|
|
grammar_type, grammar, self.tokenizer
|
2024-02-15 02:28:10 -07:00
|
|
|
)
|
|
|
|
self.fsms.append(fsm)
|
|
|
|
|
|
|
|
def __call__(
|
|
|
|
self,
|
|
|
|
logits: torch.Tensor,
|
|
|
|
fsm_grammar_states: List[int],
|
|
|
|
):
|
|
|
|
mask = torch.full_like(logits, -math.inf)
|
|
|
|
for i in range(logits.shape[0]):
|
|
|
|
fsm = self.fsms[i]
|
|
|
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
|
|
|
continue
|
|
|
|
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
|
|
|
mask[i, allowed_tokens] = 0
|
|
|
|
logits += mask
|
|
|
|
return logits
|
|
|
|
|
2024-02-16 09:50:57 -07:00
|
|
|
def advance_batch(self, next_token_ids, fsm_grammar_states):
|
2024-02-15 02:28:10 -07:00
|
|
|
return [
|
|
|
|
GrammarLogitProcessor._advance(
|
|
|
|
next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
|
|
|
|
)
|
|
|
|
for i in range(len(next_token_ids))
|
|
|
|
]
|
|
|
|
|
|
|
|
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
|
|
|
|
return GrammarLogitProcessor._advance(
|
|
|
|
next_token_id, fsm_grammar_state, self.fsms[index]
|
|
|
|
)
|
|
|
|
|
|
|
|
def filter(self, indices):
|
2024-02-16 09:50:57 -07:00
|
|
|
new_fsms = []
|
|
|
|
for i in indices:
|
|
|
|
new_fsms.append(self.fsms[i])
|
|
|
|
self.fsms = new_fsms
|
|
|
|
return self
|