fix(server): fix index out of range for watermarking (#110)

This commit is contained in:
OlivierDehaene 2023-03-08 18:29:08 +01:00 committed by GitHub
parent 2c5df5d2af
commit 941cd42e0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 8 additions and 13 deletions

View File

@@ -73,7 +73,7 @@ class CausalLMBatch(Batch):
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer

View File

@@ -103,7 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer

View File

@@ -83,7 +83,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer

View File

@@ -36,7 +36,6 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
-        vocab_size,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
@@ -52,7 +51,7 @@ class NextTokenChooser:
         sampling = do_sample
         if watermark:
-            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+            warpers.append(WatermarkLogitsProcessor(device=device))
         if repetition_penalty is not None and repetition_penalty != 1.0:
             warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
@@ -85,11 +84,9 @@ class NextTokenChooser:
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
-        vocab_size: int,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
-            vocab_size=vocab_size,
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,

View File

@@ -25,14 +25,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
-        vocab_size: int,
         gamma: float = GAMMA,
         delta: float = DELTA,
         hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
-        self.vocab_size = vocab_size
         self.gamma = gamma
         self.delta = delta
         self.rng = torch.Generator(device=device)
@@ -45,13 +43,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)

-    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
-        greenlist_size = int(self.vocab_size * self.gamma)
+        greenlist_size = int(max_value * self.gamma)
         vocab_permutation = torch.randperm(
-            self.vocab_size, device=input_ids.device, generator=self.rng
+            max_value, device=input_ids.device, generator=self.rng
         )
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
@@ -76,7 +74,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
-        greenlist_ids = self._get_greenlist_ids(input_ids[0])
+        greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )