fix(server): fix index out of range for watermarking (#110)
This commit is contained in:
parent
2c5df5d2af
commit
941cd42e0c
|
@ -73,7 +73,7 @@ class CausalLMBatch(Batch):
|
|||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
|
||||
NextTokenChooser.from_pb(r.parameters, device)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
|
|
|
@ -103,7 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
|
||||
NextTokenChooser.from_pb(r.parameters, device)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
|
|
|
@ -83,7 +83,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
decoder_input_ids.append(tokenizer.bos_token_id)
|
||||
decoder_input_lengths.append(1)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
|
||||
NextTokenChooser.from_pb(r.parameters, device)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
|
|
|
@ -36,7 +36,6 @@ class Greedy:
|
|||
class NextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size,
|
||||
watermark=False,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
|
@ -52,7 +51,7 @@ class NextTokenChooser:
|
|||
sampling = do_sample
|
||||
|
||||
if watermark:
|
||||
warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
|
||||
warpers.append(WatermarkLogitsProcessor(device=device))
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
if temperature is not None and temperature != 1.0:
|
||||
|
@ -85,11 +84,9 @@ class NextTokenChooser:
|
|||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.NextTokenChooserParameters,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> "NextTokenChooser":
|
||||
return NextTokenChooser(
|
||||
vocab_size=vocab_size,
|
||||
watermark=pb.watermark,
|
||||
temperature=pb.temperature,
|
||||
repetition_penalty=pb.repetition_penalty,
|
||||
|
|
|
@ -25,14 +25,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
|
|||
class WatermarkLogitsProcessor(LogitsProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
gamma: float = GAMMA,
|
||||
delta: float = DELTA,
|
||||
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
||||
device: str = "cpu",
|
||||
):
|
||||
# watermarking parameters
|
||||
self.vocab_size = vocab_size
|
||||
self.gamma = gamma
|
||||
self.delta = delta
|
||||
self.rng = torch.Generator(device=device)
|
||||
|
@ -45,13 +43,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
|||
prev_token = input_ids[-1].item()
|
||||
self.rng.manual_seed(self.hash_key * prev_token)
|
||||
|
||||
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
|
||||
def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
|
||||
# seed the rng using the previous tokens/prefix
|
||||
self._seed_rng(input_ids)
|
||||
|
||||
greenlist_size = int(self.vocab_size * self.gamma)
|
||||
greenlist_size = int(max_value * self.gamma)
|
||||
vocab_permutation = torch.randperm(
|
||||
self.vocab_size, device=input_ids.device, generator=self.rng
|
||||
max_value, device=input_ids.device, generator=self.rng
|
||||
)
|
||||
greenlist_ids = vocab_permutation[:greenlist_size]
|
||||
return greenlist_ids
|
||||
|
@ -76,7 +74,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
|||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
assert len(input_ids) == 1
|
||||
greenlist_ids = self._get_greenlist_ids(input_ids[0])
|
||||
greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
|
||||
green_tokens_mask = self._calc_greenlist_mask(
|
||||
scores=scores, greenlist_token_ids=greenlist_ids
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue