feat(server): use cuda graph in logits warping (#302)

OlivierDehaene 2023-05-10 19:08:54 +02:00 committed by GitHub
parent 35ab6cfcf1
commit a6c18c39bb
1 changed file with 93 additions and 32 deletions


@@ -1,8 +1,8 @@
 import re
 import torch
 
+from functools import lru_cache
 from transformers import (
-    LogitsProcessorList,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -25,8 +25,10 @@ class Sampling:
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, -1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
-        return next_tokens
+        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+        q = torch.empty_like(probs).exponential_(1, generator=self.generator).div_(probs)
+        return q.argmax()
 
 
 class Greedy:
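Note: the replacement implements multinomial sampling via the exponential-distribution form of the Gumbel-max trick described in the PyTorch source linked above: with q drawn element-wise from Exp(1), argmax(p / q) returns index i with probability p_i, using only element-wise ops plus an argmax. A minimal CPU-runnable sketch of that formulation (the helper name and toy distribution are invented for illustration):

    import torch

    def exponential_trick_sample(probs: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
        # q ~ Exp(1), drawn element-wise; argmax(probs / q) samples index i with probability probs[i].
        q = torch.empty_like(probs).exponential_(1, generator=generator)
        return probs.div(q).argmax(dim=-1)

    # Quick empirical check against the target distribution.
    gen = torch.Generator().manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])
    counts = torch.zeros(3)
    for _ in range(20_000):
        counts[exponential_trick_sample(probs, gen)] += 1
    print(counts / counts.sum())  # roughly tensor([0.1, 0.2, 0.7])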
@@ -34,6 +36,63 @@ class Greedy:
         return logits.argmax()
 
 
+class StaticWarper:
+    def __init__(
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+    ):
+        self.warpers = []
+
+        if temperature is not None and temperature != 1.0:
+            temperature = float(temperature)
+            self.warpers.append(TemperatureLogitsWarper(temperature))
+        if top_k is not None and top_k != 0:
+            self.warpers.append(TopKLogitsWarper(top_k=top_k))
+        if top_p is not None and top_p < 1.0:
+            self.warpers.append(TopPLogitsWarper(top_p=top_p))
+        if typical_p is not None and typical_p < 1.0:
+            self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+        self.static_next_logprob = None
+
+    def __call__(self, scores):
+        if self.cuda_graph is None:
+            self.static_scores = scores
+            self.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph):
+                for warper in self.warpers:
+                    self.static_warped_scores = warper(None, self.static_scores)
+
+                # Compute logprobs
+                self.static_next_logprob = torch.log_softmax(
+                    self.static_warped_scores, -1
+                )
+
+        self.static_scores.copy_(scores)
+        self.cuda_graph.replay()
+
+        return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(
+    temperature: Optional[float],
+    top_k: Optional[int],
+    top_p: Optional[float],
+    typical_p: Optional[float],
+) -> StaticWarper:
+    return StaticWarper(
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+    )
+
+
 class NextTokenChooser:
     def __init__(
         self,
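Note: StaticWarper relies on PyTorch's CUDA graph API. The warper chain and the log_softmax are recorded once into a torch.cuda.CUDAGraph against fixed tensors; every later call only copies fresh scores into the captured input buffer and replays the graph, so per-token warping becomes a single graph launch. A stripped-down sketch of that capture/replay pattern (illustrative only, not the repository's code; it assumes a CUDA device, invents shapes and names, and adds the warm-up side stream the PyTorch docs recommend before capture):

    import torch

    assert torch.cuda.is_available()  # capture/replay requires a CUDA device

    static_scores = torch.randn(1, 32000, device="cuda")  # invented vocab size
    graph = torch.cuda.CUDAGraph()

    # Warm up on a side stream before capturing (recommended by the PyTorch docs).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        torch.log_softmax(static_scores / 0.7, dim=-1)
    torch.cuda.current_stream().wait_stream(s)

    with torch.cuda.graph(graph):
        # Everything recorded here is replayed verbatim against static_scores.
        static_warped = static_scores / 0.7                      # stand-in for the warper chain
        static_logprobs = torch.log_softmax(static_warped, dim=-1)

    def warp(scores: torch.Tensor):
        static_scores.copy_(scores)  # refresh the captured input buffer
        graph.replay()               # re-run the recorded kernels
        return static_warped, static_logprobs

    new_scores = torch.randn(1, 32000, device="cuda")
    warped, logprobs = warp(new_scores)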
@@ -47,43 +106,45 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
     ):
-        warpers = LogitsProcessorList()
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        sampling = do_sample
-
-        if watermark:
-            warpers.append(WatermarkLogitsProcessor(device=device))
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
-        if temperature is not None and temperature != 1.0:
-            temperature = float(temperature)
-            warpers.append(TemperatureLogitsWarper(temperature))
-            sampling = True
-        if top_k is not None and top_k != 0:
-            warpers.append(TopKLogitsWarper(top_k=top_k))
-            sampling = True
-        if top_p is not None and top_p < 1.0:
-            warpers.append(TopPLogitsWarper(top_p=top_p))
-            sampling = True
-        if typical_p is not None and typical_p < 1.0:
-            warpers.append(TypicalLogitsWarper(mass=typical_p))
-            sampling = True
-
-        self.warpers = warpers
+        self.watermark_processor = (
+            WatermarkLogitsProcessor(device=device) if watermark else None
+        )
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty
+            else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
+        )
+        if has_warpers:
+            self.static_warper = static_warper(
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+            )
+        else:
+            self.static_warper = None
+
+        sampling = do_sample or has_warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
-        # Warp logits
-        scores = self.warpers(input_ids, scores)
-
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, -1)
-
-        # Choose tokens
-        next_id = self.choice(scores[-1])
-
-        return next_id.view(1, 1), logprobs
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
+
+        if self.static_warper is None:
+            next_logprob = torch.log_softmax(scores, -1)
+        else:
+            scores, next_logprob = self.static_warper(scores)
+
+        next_id = self.choice(scores[-1]).view(1, 1)
+
+        return next_id, next_logprob
 
     @classmethod
     def from_pb(
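Note: because static_warper is wrapped in @lru_cache(10), functools returns the same StaticWarper object, and therefore the same captured graph and static buffers, for every NextTokenChooser created with identical (temperature, top_k, top_p, typical_p) values. A toy sketch of that caching behavior with a stand-in class (names are invented; no GPU is needed for the sketch):

    from functools import lru_cache

    class FakeStaticWarper:
        """Stand-in for StaticWarper; only counts how many instances get built."""
        instances = 0

        def __init__(self, temperature, top_k, top_p, typical_p):
            FakeStaticWarper.instances += 1

    @lru_cache(10)
    def fake_static_warper(temperature, top_k, top_p, typical_p):
        return FakeStaticWarper(temperature, top_k, top_p, typical_p)

    a = fake_static_warper(0.7, 50, 0.9, None)
    b = fake_static_warper(0.7, 50, 0.9, None)   # cache hit: same object, no new capture
    c = fake_static_warper(1.3, 0, 1.0, None)    # different settings: new object
    print(a is b, a is c, FakeStaticWarper.instances)  # True False 2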