From 5fa8ae041cef2b5f5587d4eb076dbaeb5bf992f6 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 12 Apr 2023 12:03:10 +0200
Subject: [PATCH] feat(server): optimize decode for sane tokenizers (#170)

---
 benchmark/Cargo.lock                           |  4 +-
 server/text_generation_server/models/bloom.py  |  8 ++-
 .../models/causal_lm.py                        | 11 ++--
 .../models/flash_causal_lm.py                  |  6 +--
 .../models/flash_santacoder.py                 |  3 +-
 server/text_generation_server/models/model.py  | 51 ++++++++++++++-----
 .../models/santacoder.py                       |  3 +-
 .../models/seq2seq_lm.py                       | 13 +++--
 8 files changed, 67 insertions(+), 32 deletions(-)

diff --git a/benchmark/Cargo.lock b/benchmark/Cargo.lock
index cb9c7533..74610932 100644
--- a/benchmark/Cargo.lock
+++ b/benchmark/Cargo.lock
@@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
 
 [[package]]
 name = "grpc-metadata"
-version = "0.4.1"
+version = "0.1.0"
 dependencies = [
  "opentelemetry",
  "tonic",
@@ -2140,7 +2140,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.4.3"
+version = "0.5.0"
 dependencies = [
  "futures",
  "grpc-metadata",
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 1a961027..efcc9e05 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        super(BLOOM, self).__init__(
+            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return BloomCausalLMBatch
@@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 8c092d6a..6347b1a5 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
 
 
 class CausalLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -319,8 +325,7 @@ class CausalLM(Model):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3801ed24..507fec0f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -212,7 +212,8 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize=False,
+        quantize: bool = False,
+        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -237,8 +238,7 @@ class FlashCausalLM(Model):
         )
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 2f680995..e10d259d 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device).to(dtype)
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 5b82872c..08a48553 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
+        if decode_buffer < 1:
+            raise ValueError("decode_buffer must be >= 1")
+
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
+        self.decode_buffer = decode_buffer
 
     @property
     @abstractmethod
@@ -39,23 +48,37 @@ class Model(ABC):
         )
 
         if token_offset is None:
-            token_offset = len(all_input_ids) - 3
+            token_offset = len(all_input_ids) - self.decode_buffer
+            # left token buffer
+            if self.decode_buffer > 1:
+                # Decode token_offset token minus last one and token_offset tokens
+                raw_texts = self.tokenizer.batch_decode(
+                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
+                    skip_special_tokens=False,
+                )
 
-        # Decode token_offset token minus last one and token_offset tokens
-        results = self.tokenizer.batch_decode(
-            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-            skip_special_tokens=False,
-        )
-
-        # default offset is only the last token
-        if offset is None:
-            offset = len(results[0])
+                # default offset is only the last token
+                offset = len(raw_texts[0])
+                sequence_text = raw_texts[1]
+            else:
+                # Only decode the last token without using a token buffer
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
+                # no offset in this case
+                offset = 0
+        else:
+            assert offset is not None
+            sequence_text = self.tokenizer.decode(
+                all_input_ids[token_offset:],
+                skip_special_tokens=False,
+            )
 
         # get text
-        text = results[1][offset:]
+        token_text = sequence_text[offset:]
 
         # if text is utf-8
-        if text and text[-1] != "�":
-            return text, None, None
+        if token_text and token_text[-1] != "�":
+            return token_text, None, None
         else:
             return "", offset, token_offset
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 58361a8d..8646a4e1 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     def decode(self, generated_ids: List[int]) -> str:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 13eafd62..82b5def0 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
@@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
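
Note (illustration, not part of the patch): the sketch below is a minimal standalone rendering of the buffered incremental decode that Model.decode_token implements in server/text_generation_server/models/model.py above. It assumes a Hugging Face tokenizer and decode_buffer > 1; the decode_buffer == 1 branch is omitted, and the "gpt2" model id and the free-standing decode_token helper are purely illustrative.

# Minimal sketch of the buffered incremental decode (assumes decode_buffer > 1).
from typing import List, Optional, Tuple

from transformers import AutoTokenizer


def decode_token(
    tokenizer,
    all_input_ids: List[int],
    offset: Optional[int] = None,
    token_offset: Optional[int] = None,
    decode_buffer: int = 3,
) -> Tuple[str, Optional[int], Optional[int]]:
    if token_offset is None:
        # Only re-decode a small trailing buffer instead of the whole sequence,
        # so tokenizers that merge pieces (BPE, SentencePiece) still produce the
        # correct surface text for the newest token.
        token_offset = len(all_input_ids) - decode_buffer
        texts = tokenizer.batch_decode(
            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
            skip_special_tokens=False,
        )
        # The new token's text is whatever the longer decode adds on top of
        # the shorter one.
        offset = len(texts[0])
        sequence_text = texts[1]
    else:
        # A previous step returned an incomplete character: decode again from
        # the remembered token_offset.
        sequence_text = tokenizer.decode(
            all_input_ids[token_offset:], skip_special_tokens=False
        )

    token_text = sequence_text[offset:]
    if token_text and token_text[-1] != "�":
        # Complete UTF-8 text: reset both offsets for the next step.
        return token_text, None, None
    # Incomplete multi-byte character: keep the offsets and retry after the
    # next token has been generated.
    return "", offset, token_offset


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok(" Streaming decode with a small token buffer").input_ids
    print(decode_token(tok, ids))  # e.g. (' buffer', None, None)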