From 9b205d33cc349a96937204ad16ebc1a578ad619b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 6 Mar 2023 13:22:58 +0100 Subject: [PATCH] fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100) --- launcher/tests/mt0_base.json | 22 ++++++++++----------- server/tests/models/test_seq2seq_lm.py | 2 +- server/text_generation/models/causal_lm.py | 4 +--- server/text_generation/models/model.py | 18 +++++++++++++++++ server/text_generation/models/seq2seq_lm.py | 8 ++++---- server/text_generation/utils/watermark.py | 20 +++++++++---------- 6 files changed, 45 insertions(+), 29 deletions(-) diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json index c06a2c26..22c9499f 100644 --- a/launcher/tests/mt0_base.json +++ b/launcher/tests/mt0_base.json @@ -14,7 +14,7 @@ "tokens": [ { "id": 259, - "text": "", + "text": " ", "logprob": -1.3656927, "special": false }, @@ -32,13 +32,13 @@ }, { "id": 287, - "text": "the", + "text": " the", "logprob": -1.2102449, "special": false }, { "id": 259, - "text": "", + "text": " ", "logprob": -1.6057279, "special": false }, @@ -50,19 +50,19 @@ }, { "id": 304, - "text": "of", + "text": " of", "logprob": -0.5270343, "special": false }, { "id": 287, - "text": "the", + "text": " the", "logprob": -0.62522805, "special": false }, { "id": 259, - "text": "", + "text": " ", "logprob": -1.4069618, "special": false }, @@ -74,19 +74,19 @@ }, { "id": 304, - "text": "of", + "text": " of", "logprob": -1.3172221, "special": false }, { "id": 287, - "text": "the", + "text": " the", "logprob": -0.3501925, "special": false }, { "id": 259, - "text": "", + "text": " ", "logprob": -0.7219573, "special": false }, @@ -104,7 +104,7 @@ }, { "id": 259, - "text": "", + "text": " ", "logprob": -0.32933083, "special": false }, @@ -116,7 +116,7 @@ }, { "id": 2978, - "text": "test", + "text": " test", "logprob": -1.5846587, "special": false }, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index c6eacba7..f7173392 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == "" for generation in generations]) + assert all([generation.token_text == " " for generation in generations]) assert generations[0].request_id == 0 diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 36d419e3..23c94ddf 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -385,10 +385,8 @@ class CausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.tokenizer.decode( + next_token_text = self.decode_token( next_token_id_squeezed, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, ) # Evaluate stopping criteria diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index 52480526..09fa6a2a 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -15,6 +15,15 @@ class Model(ABC): self.all_special_ids = set(tokenizer.all_special_ids) self.device = device + # see `decode_token` method + self.tokenizer.add_special_tokens( + {"additional_special_tokens": [""]} + ) + self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids( + "" + ) + self.special_decode_token_length = len("") + @property @abstractmethod def batch_type(self) -> Type[B]: @@ -23,3 +32,12 @@ class Model(ABC): @abstractmethod def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError + + def decode_token(self, token_id: int) -> str: + """Hack to hopefully support generate_stream for the maximum number of tokenizers""" + # append token to special decode token and decode both + result = self.tokenizer.decode( + [self.special_decode_token_id, token_id], skip_special_tokens=False + ) + # slice to remove special decode token + return result[self.special_decode_token_length :] diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 38089967..4b88baec 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -342,7 +342,9 @@ class Seq2SeqLM(Model): return Seq2SeqLMBatch def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) + return self.tokenizer.decode( + decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) def forward( self, @@ -457,10 +459,8 @@ class Seq2SeqLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.tokenizer.decode( + next_token_text = self.decode_token( next_token_id_squeezed, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, ) # Evaluate stopping criteria diff --git a/server/text_generation/utils/watermark.py b/server/text_generation/utils/watermark.py index 6f5664fe..4c42d2a1 100644 --- a/server/text_generation/utils/watermark.py +++ b/server/text_generation/utils/watermark.py @@ -24,12 +24,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0) class WatermarkLogitsProcessor(LogitsProcessor): def __init__( - self, - vocab_size: int, - gamma: float = GAMMA, - delta: float = DELTA, - hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width - device: str = "cpu", + self, + vocab_size: int, + gamma: float = GAMMA, + delta: float = DELTA, + hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width + device: str = "cpu", ): # watermarking parameters self.vocab_size = vocab_size @@ -40,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor): def _seed_rng(self, input_ids: torch.LongTensor) -> None: assert ( - input_ids.shape[-1] >= 1 + input_ids.shape[-1] >= 1 ), "requires at least a 1 token prefix sequence to seed rng" prev_token = input_ids[-1].item() self.rng.manual_seed(self.hash_key * prev_token) @@ -58,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor): @staticmethod def _calc_greenlist_mask( - scores: torch.FloatTensor, greenlist_token_ids + scores: torch.FloatTensor, greenlist_token_ids ) -> torch.BoolTensor: green_tokens_mask = torch.zeros_like(scores) green_tokens_mask[-1, greenlist_token_ids] = 1 @@ -67,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor): @staticmethod def _bias_greenlist_logits( - scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float + scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float ) -> torch.Tensor: scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias return scores def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor + self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: assert len(input_ids) == 1 greenlist_ids = self._get_greenlist_ids(input_ids[0])