feat(server): use cuda graph in logits warping (#302)
This commit is contained in:
parent
35ab6cfcf1
commit
a6c18c39bb
|
@ -1,8 +1,8 @@
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LogitsProcessorList,
|
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
|
@ -25,8 +25,10 @@ class Sampling:
|
||||||
|
|
||||||
def __call__(self, logits):
|
def __call__(self, logits):
|
||||||
probs = torch.nn.functional.softmax(logits, -1)
|
probs = torch.nn.functional.softmax(logits, -1)
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
|
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||||
return next_tokens
|
q = torch.empty_like(probs).exponential_(1, generator=self.generator).div_(probs)
|
||||||
|
|
||||||
|
return q.argmax()
|
||||||
|
|
||||||
|
|
||||||
class Greedy:
|
class Greedy:
|
||||||
|
@ -34,6 +36,63 @@ class Greedy:
|
||||||
return logits.argmax()
|
return logits.argmax()
|
||||||
|
|
||||||
|
|
||||||
|
class StaticWarper:
    """Apply a fixed chain of logits warpers, using a CUDA graph when possible.

    The warpers to run are selected once, at construction time, from the
    sampling parameters.  On the first call with a CUDA tensor the whole
    chain plus the final ``log_softmax`` is captured into a
    ``torch.cuda.CUDAGraph``; every later call copies the new scores into the
    captured input buffer and replays the graph.

    NOTE(review): the captured graph is tied to the tensor given on the first
    call, so later calls presumably must use the same shape/dtype/device --
    confirm against the callers.
    """

    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
    ):
        """Build the list of warpers that are actually needed.

        Args:
            temperature: softmax temperature; ``None`` or ``1.0`` adds no warper.
            top_k: top-k cutoff; ``None`` or ``0`` adds no warper.
            top_p: nucleus cutoff; ``None`` or ``>= 1.0`` adds no warper.
            typical_p: typical-decoding mass; ``None`` or ``>= 1.0`` adds no
                warper.
        """
        self.warpers = []

        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            self.warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            self.warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            self.warpers.append(TopPLogitsWarper(top_p=top_p))
        if typical_p is not None and typical_p < 1.0:
            self.warpers.append(TypicalLogitsWarper(mass=typical_p))

        # Lazily populated on the first CUDA call.
        self.cuda_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def _warp(self, scores):
        """Run every warper in order and return ``(warped_scores, logprobs)``."""
        warped = scores
        for warper in self.warpers:
            # Feed each warper the output of the previous one so that all of
            # them contribute; warpers are called with ``None`` as input_ids.
            warped = warper(None, warped)
        return warped, torch.log_softmax(warped, -1)

    def __call__(self, scores):
        """Warp ``scores`` and compute the matching log-probabilities.

        Args:
            scores: tensor of logits to warp.

        Returns:
            A ``(warped_scores, next_logprob)`` tuple.  On the CUDA path both
            are the graph's static output tensors (the same objects on every
            call), so they are overwritten by the next call.
        """
        if not scores.is_cuda:
            # CUDA graphs cannot be captured for non-CUDA tensors; run eagerly.
            return self._warp(scores)

        if self.cuda_graph is None:
            graph = torch.cuda.CUDAGraph()
            # The first tensor received becomes the graph's input buffer.
            self.static_scores = scores

            with torch.cuda.graph(graph):
                (
                    self.static_warped_scores,
                    self.static_next_logprob,
                ) = self._warp(self.static_scores)

            # Only publish the graph once capture has succeeded.
            self.cuda_graph = graph

        self.static_scores.copy_(scores)
        self.cuda_graph.replay()

        return self.static_warped_scores, self.static_next_logprob
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    """Return a memoized ``StaticWarper`` for the given sampling parameters.

    Results are kept in an LRU cache holding up to 10 entries, so repeated
    calls made with the same arguments receive the same ``StaticWarper``
    instance instead of building a new one.
    """
    warper_options = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "typical_p": typical_p,
    }
    return StaticWarper(**warper_options)
|
||||||
|
|
||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -47,43 +106,45 @@ class NextTokenChooser:
|
||||||
seed=0,
|
seed=0,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
):
|
):
|
||||||
warpers = LogitsProcessorList()
|
self.watermark_processor = (
|
||||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
# all samplers can be found in `generation_utils_samplers.py`
|
)
|
||||||
sampling = do_sample
|
self.repetition_processor = (
|
||||||
|
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||||
|
if repetition_penalty
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if watermark:
|
has_warpers = (
|
||||||
warpers.append(WatermarkLogitsProcessor(device=device))
|
(temperature is not None and temperature != 1.0)
|
||||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
or (top_k is not None and top_k != 0)
|
||||||
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
or (top_p is not None and top_p < 1.0)
|
||||||
if temperature is not None and temperature != 1.0:
|
or (typical_p is not None and typical_p < 1.0)
|
||||||
temperature = float(temperature)
|
)
|
||||||
warpers.append(TemperatureLogitsWarper(temperature))
|
if has_warpers:
|
||||||
sampling = True
|
self.static_warper = static_warper(
|
||||||
if top_k is not None and top_k != 0:
|
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||||
warpers.append(TopKLogitsWarper(top_k=top_k))
|
)
|
||||||
sampling = True
|
else:
|
||||||
if top_p is not None and top_p < 1.0:
|
self.static_warper = None
|
||||||
warpers.append(TopPLogitsWarper(top_p=top_p))
|
|
||||||
sampling = True
|
|
||||||
if typical_p is not None and typical_p < 1.0:
|
|
||||||
warpers.append(TypicalLogitsWarper(mass=typical_p))
|
|
||||||
sampling = True
|
|
||||||
|
|
||||||
self.warpers = warpers
|
sampling = do_sample or has_warpers
|
||||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
def __call__(self, input_ids, scores):
|
||||||
# Warp logits
|
if self.watermark_processor:
|
||||||
scores = self.warpers(input_ids, scores)
|
scores = self.watermark_processor(input_ids, scores)
|
||||||
|
if self.repetition_processor:
|
||||||
|
scores = self.repetition_processor(input_ids, scores)
|
||||||
|
|
||||||
# Compute logprobs
|
if self.static_warper is None:
|
||||||
logprobs = torch.log_softmax(scores, -1)
|
next_logprob = torch.log_softmax(scores, -1)
|
||||||
|
else:
|
||||||
|
scores, next_logprob = self.static_warper(scores)
|
||||||
|
|
||||||
# Choose tokens
|
next_id = self.choice(scores[-1]).view(1, 1)
|
||||||
next_id = self.choice(scores[-1])
|
|
||||||
|
|
||||||
return next_id.view(1, 1), logprobs
|
return next_id, next_logprob
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
|
|
Loading…
Reference in New Issue