From 53aa9194c8c070afd19fa4660305dab2b280adf3 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 20 Jun 2023 11:06:10 +0200
Subject: [PATCH] fix(server): fix warpers on CPU (#472)

Closes #471
---
 .../text_generation_server/models/__init__.py | 20 ++++-------
 .../utils/logits_process.py                   | 36 +++++++++++--------
 2 files changed, 27 insertions(+), 29 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9540d99e..3fdc23b2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -237,20 +237,12 @@ def get_model(
         )
 
     elif model_type == "t5":
-        if sharded:
-            return T5Sharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            return Seq2SeqLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
+        return T5Sharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
+        )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index faa94516..0cbbf8b0 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -42,25 +42,31 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
+        if torch.cuda.is_available():
+            if self.cuda_graph is None:
+                self.static_scores = scores
+                self.cuda_graph = torch.cuda.CUDAGraph()
 
-            with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                local_scores = self.static_scores
-                for warper in self.warpers:
-                    local_scores = warper(None, local_scores)
+                with torch.cuda.graph(self.cuda_graph, pool=mempool):
+                    local_scores = self.static_scores
+                    for warper in self.warpers:
+                        local_scores = warper(None, local_scores)
 
-                self.static_warped_scores = local_scores
-                # Compute logprobs
-                self.static_next_logprob = torch.log_softmax(
-                    self.static_warped_scores, -1
-                )
+                    self.static_warped_scores = local_scores
+                    # Compute logprobs
+                    self.static_next_logprob = torch.log_softmax(
+                        self.static_warped_scores, -1
+                    )
 
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
+            self.static_scores.copy_(scores)
+            self.cuda_graph.replay()
 
-        return self.static_warped_scores, self.static_next_logprob
+            return self.static_warped_scores, self.static_next_logprob
+
+        # CPU branch
+        for warper in self.warpers:
+            scores = warper(None, scores)
+        return scores, torch.log_softmax(scores, -1)
 
 
 @lru_cache(10)