From a6c18c39bb9431e6d3aba407f1d30eb43d82c04d Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 10 May 2023 19:08:54 +0200
Subject: [PATCH] feat(server): use cuda graph in logits warping (#302)

---
 server/text_generation_server/utils/tokens.py | 125 +++++++++++++-----
 1 file changed, 93 insertions(+), 32 deletions(-)

diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index d5a7717..c3477ee 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,8 +1,8 @@
 import re
 import torch
 
+from functools import lru_cache
 from transformers import (
-    LogitsProcessorList,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -25,8 +25,10 @@ class Sampling:
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, -1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
-        return next_tokens
+        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+
+        return probs.div_(q).argmax()
 
 
 class Greedy:
@@ -34,6 +36,63 @@ class Greedy:
         return logits.argmax()
 
 
+class StaticWarper:
+    def __init__(
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+    ):
+        self.warpers = []
+
+        if temperature is not None and temperature != 1.0:
+            temperature = float(temperature)
+            self.warpers.append(TemperatureLogitsWarper(temperature))
+        if top_k is not None and top_k != 0:
+            self.warpers.append(TopKLogitsWarper(top_k=top_k))
+        if top_p is not None and top_p < 1.0:
+            self.warpers.append(TopPLogitsWarper(top_p=top_p))
+        if typical_p is not None and typical_p < 1.0:
+            self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+        self.static_next_logprob = None
+
+    def __call__(self, scores):
+        if self.cuda_graph is None:
+            self.static_scores = scores
+            self.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph):
+                for warper in self.warpers:
+                    self.static_warped_scores = warper(None, self.static_scores)
+
+                # Compute logprobs
+                self.static_next_logprob = torch.log_softmax(
+                    self.static_warped_scores, -1
+                )
+
+        self.static_scores.copy_(scores)
+        self.cuda_graph.replay()
+
+        return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(
+    temperature: Optional[float],
+    top_k: Optional[int],
+    top_p: Optional[float],
+    typical_p: Optional[float],
+) -> StaticWarper:
+    return StaticWarper(
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+    )
+
+
 class NextTokenChooser:
     def __init__(
         self,
@@ -47,43 +106,45 @@
         seed=0,
         device="cpu",
     ):
-        warpers = LogitsProcessorList()
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        sampling = do_sample
+        self.watermark_processor = (
+            WatermarkLogitsProcessor(device=device) if watermark else None
+        )
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty
+            else None
+        )
 
-        if watermark:
-            warpers.append(WatermarkLogitsProcessor(device=device))
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
-        if temperature is not None and temperature != 1.0:
-            temperature = float(temperature)
-            warpers.append(TemperatureLogitsWarper(temperature))
-            sampling = True
-        if top_k is not None and top_k != 0:
-            warpers.append(TopKLogitsWarper(top_k=top_k))
-            sampling = True
-        if top_p is not None and top_p < 1.0:
-            warpers.append(TopPLogitsWarper(top_p=top_p))
-            sampling = True
-        if typical_p is not None and typical_p < 1.0:
-            warpers.append(TypicalLogitsWarper(mass=typical_p))
-            sampling = True
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
+        )
+        if has_warpers:
+            self.static_warper = static_warper(
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+            )
+        else:
+            self.static_warper = None
 
-        self.warpers = warpers
+        sampling = do_sample or has_warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
-        # Warp logits
-        scores = self.warpers(input_ids, scores)
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
 
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, -1)
+        if self.static_warper is None:
+            next_logprob = torch.log_softmax(scores, -1)
+        else:
+            scores, next_logprob = self.static_warper(scores)
 
-        # Choose tokens
-        next_id = self.choice(scores[-1])
+        next_id = self.choice(scores[-1]).view(1, 1)
 
-        return next_id.view(1, 1), logprobs
+        return next_id, next_logprob
 
     @classmethod
     def from_pb(
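Notes:

StaticWarper records its chain of logits warpers into a torch.cuda.CUDAGraph on the first call, then replays the recorded kernels on every later call after copying the fresh scores into the captured input tensor. Because static_warper is wrapped in lru_cache(10), requests sharing the same (temperature, top_k, top_p, typical_p) tuple also share one StaticWarper, hence one captured graph and one set of static buffers. The sketch below is a minimal, self-contained illustration of that capture/replay pattern only, not the patch's code; the buffer shape, the 0.7 temperature, and the explicit warmup loop are assumptions for the demo (the patch captures without a separate warmup phase):

    import torch

    assert torch.cuda.is_available(), "CUDA graph capture needs a GPU"

    # Static input buffer: the captured graph always reads this exact
    # tensor, so fresh data must be copied into it before each replay.
    static_scores = torch.randn(1, 32000, device="cuda")
    graph = torch.cuda.CUDAGraph()

    # Warmup on a side stream so lazy initialization does not happen
    # during capture (as the PyTorch CUDA-graph docs recommend).
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            torch.log_softmax(static_scores / 0.7, -1)
    torch.cuda.current_stream().wait_stream(side)

    # Capture: kernels launched in this block are recorded, not run.
    with torch.cuda.graph(graph):
        static_out = torch.log_softmax(static_scores / 0.7, -1)

    def warped_logprobs(scores: torch.Tensor) -> torch.Tensor:
        static_scores.copy_(scores)  # feed the captured input buffer
        graph.replay()               # re-run the recorded kernels
        return static_out            # updated in place by the replay

    print(warped_logprobs(torch.randn(1, 32000, device="cuda"))[0, :5])

The change to Sampling rests on a standard identity: with independent q_i ~ Exp(1), argmax_i(p_i / q_i) is one draw from the categorical distribution with probabilities p_i. This is the Gumbel-max trick in disguise, since log p_i - log q_i = log p_i + Gumbel(0, 1) noise, and it is the same scheme used internally by the PyTorch multinomial code linked in the patch. A quick empirical check of the identity, runnable on CPU:

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])

    counts = torch.zeros(3)
    for _ in range(20_000):
        # One categorical draw without torch.multinomial.
        q = torch.empty_like(probs).exponential_(1)
        counts[probs.div(q).argmax()] += 1

    # Empirical frequencies converge to [0.1, 0.2, 0.7].
    print(counts / counts.sum())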