diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 4cfc15b9..ef3f0260 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -73,7 +73,7 @@ class CausalLMBatch(Batch):
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 484621d3..d04a3bce 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -103,7 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 73a10879..bece913a 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -83,7 +83,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 41dfabd3..aa76b6eb 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -36,7 +36,6 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
-        vocab_size,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
@@ -52,7 +51,7 @@ class NextTokenChooser:
         sampling = do_sample
 
         if watermark:
-            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+            warpers.append(WatermarkLogitsProcessor(device=device))
         if repetition_penalty is not None and repetition_penalty != 1.0:
             warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
@@ -85,11 +84,9 @@ class NextTokenChooser:
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
-        vocab_size: int,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
-            vocab_size=vocab_size,
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py
index 4c42d2a1..cf6214ce 100644
--- a/server/text_generation_server/utils/watermark.py
+++ b/server/text_generation_server/utils/watermark.py
@@ -25,14 +25,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
-        vocab_size: int,
         gamma: float = GAMMA,
         delta: float = DELTA,
         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
-        self.vocab_size = vocab_size
         self.gamma = gamma
         self.delta = delta
         self.rng = torch.Generator(device=device)
@@ -45,13 +43,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
 
-        greenlist_size = int(self.vocab_size * self.gamma)
+        greenlist_size = int(max_value * self.gamma)
         vocab_permutation = torch.randperm(
-            self.vocab_size, device=input_ids.device, generator=self.rng
+            max_value, device=input_ids.device, generator=self.rng
         )
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
@@ -76,7 +74,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
-        greenlist_ids = self._get_greenlist_ids(input_ids[0])
+        greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
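
Note (not part of the diff): a minimal sketch of the call pattern after this change, assuming the post-diff WatermarkLogitsProcessor above; the tensor shapes and token ids are illustrative. The processor no longer stores a vocabulary size and instead sizes its greenlist from the width of the logits row it receives (scores.shape[-1]), so it no longer relies on len(tokenizer) matching the model's logits width.

# Sketch only; shapes and values below are illustrative assumptions.
import torch

from text_generation_server.utils.watermark import WatermarkLogitsProcessor

processor = WatermarkLogitsProcessor(device="cpu")  # no vocab_size argument anymore

input_ids = torch.tensor([[42, 7, 1024]])  # one sequence; its last token seeds the rng
scores = torch.randn(1, 50272)             # one logits row; its width bounds the greenlist
scores = processor(input_ids, scores)      # greenlist sized from scores.shape[-1]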