feat(server): add logits watermark (#90)
parent f874c47831
commit 9b8ea6a6c7

@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }

 fn main() -> ExitCode {

@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;

     // Signal handler

@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,

@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,

@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }

+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(

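Aside: because these clap fields use `#[clap(long, env)]`, the options should be settable either as `--watermark-gamma` / `--watermark-delta` flags or via the same-named environment variables, and the hunk above forwards them into each shard's environment. A minimal sketch of the shard-side handoff in Python (values illustrative; mirrors the defaults in the new watermark.py further down):

    import os

    # The launcher only exports these when the flags were set; otherwise the
    # Python server falls back to the paper's suggested defaults.
    gamma = float(os.getenv("WATERMARK_GAMMA", 0.5))
    delta = float(os.getenv("WATERMARK_DELTA", 2.0))
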
@@ -40,6 +40,8 @@ message NextTokenChooserParameters {
     uint64 seed = 5;
     /// repetition penalty
     float repetition_penalty = 6;
+    /// token watermarking using "A Watermark for Large Language Models"
+    bool watermark = 7;
 }

 message StoppingCriteriaParameters {

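Aside: the new proto field rides along with the existing per-request sampling parameters, so the router can toggle watermarking per request and the Python server reads it back in NextTokenChooser.from_pb (further down). A sketch against the generated bindings, assuming the usual protobuf Python API:

    from text_generation.pb import generate_pb2

    # proto3 scalars default to false/0, so only the flag needs setting here
    params = generate_pb2.NextTokenChooserParameters(watermark=True)
    assert params.watermark
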
@@ -53,6 +53,9 @@ pub(crate) struct GenerateParameters {
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
+    #[serde(default)]
+    #[schema(default = "false", example = true)]
+    pub watermark: bool,
     #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]

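Aside: with the field exposed on GenerateParameters, a client opts in per request through the JSON body. A hedged sketch against the router's /generate route (address and prompt are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:3000/generate",  # assumed default router address
        json={
            "inputs": "What is deep learning?",
            "parameters": {"watermark": True, "max_new_tokens": 20},
        },
    )
    print(resp.json())
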
@@ -72,7 +75,8 @@ fn default_parameters() -> GenerateParameters {
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
-        stop: vec![],
+        stop: Vec::new(),
+        watermark: false,
         details: false,
         seed: None,
     }

@@ -234,6 +234,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 max_new_tokens: 0,

@@ -72,6 +72,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
             max_new_tokens: 1,
             return_full_text: None,
             stop: Vec::new(),
+            watermark: false,
             details: false,
             seed: None,
         },

@@ -157,6 +157,7 @@ fn validate(
         max_new_tokens,
         stop: stop_sequences,
         seed,
+        watermark,
         ..
     } = request.parameters;

@@ -232,6 +233,7 @@ fn validate(
         top_p,
         do_sample,
         seed,
+        watermark,
     };
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,

@@ -67,7 +67,9 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -100,7 +100,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

@@ -77,7 +77,9 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

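Aside: the three batch builders above all pass len(tokenizer) as the vocabulary size. For Hugging Face tokenizers, len() counts the base vocabulary plus any added tokens, which matches the width of the logits tensor the watermark processor permutes. A quick illustration (model name is only an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(len(tok))        # full size, including added tokens
    print(tok.vocab_size)  # base vocabulary; can be smaller than len(tok)
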
@@ -13,6 +13,7 @@ from typing import List, Tuple, Optional

 from text_generation.pb import generate_pb2
 from text_generation.pb.generate_pb2 import FinishReason
+from text_generation.utils.watermark import WatermarkLogitsProcessor


 class Sampling:

@@ -35,6 +36,8 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
+        vocab_size,
+        watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,

@@ -47,6 +50,11 @@ class NextTokenChooser:
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         sampling = do_sample
+
+        if watermark:
+            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
             temperature = float(temperature)
             warpers.append(TemperatureLogitsWarper(temperature))

@@ -57,8 +65,6 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()

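Aside: the two hunks above move the repetition penalty so the processor order becomes watermark -> repetition penalty -> temperature -> top-k -> top-p; the watermark bias is therefore applied to the raw logits before any warping. A toy chain under that assumption (WatermarkLogitsProcessor comes from the new file below):

    import torch
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    chain = LogitsProcessorList(
        [
            WatermarkLogitsProcessor(vocab_size=100),
            TemperatureLogitsWarper(0.7),
        ]
    )
    # one sequence; its last token seeds the watermark rng
    scores = chain(torch.tensor([[42]]), torch.zeros(1, 100))
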
@@ -77,9 +83,14 @@ class NextTokenChooser:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        vocab_size: int,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
+            vocab_size=vocab_size,
+            watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,

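Aside: tying the pieces together, from_pb now threads the vocabulary size through to the chooser. A sketch, assuming the `params` message from the proto snippet above and a `tokenizer` in scope as in the batch builders:

    import torch

    chooser = NextTokenChooser.from_pb(params, len(tokenizer), torch.device("cpu"))
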
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2023 Authors of "A Watermark for Large Language Models"
+# available at https://arxiv.org/abs/2301.10226
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+from transformers import LogitsProcessor
+
+# os.getenv returns a string when the variable is set (the launcher exports
+# these), so coerce to float instead of passing the raw value through
+GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
+DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
+
+
+class WatermarkLogitsProcessor(LogitsProcessor):
+    def __init__(
+        self,
+        vocab_size: int,
+        gamma: float = GAMMA,
+        delta: float = DELTA,
+        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+        device: str = "cpu",
+    ):
+        # watermarking parameters
+        self.vocab_size = vocab_size
+        self.gamma = gamma
+        self.delta = delta
+        self.rng = torch.Generator(device=device)
+        self.hash_key = hash_key
+
+    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
+        assert (
+            input_ids.shape[-1] >= 1
+        ), "requires at least a 1 token prefix sequence to seed rng"
+        prev_token = input_ids[-1].item()
+        self.rng.manual_seed(self.hash_key * prev_token)
+
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        # seed the rng using the previous tokens/prefix
+        self._seed_rng(input_ids)
+
+        greenlist_size = int(self.vocab_size * self.gamma)
+        vocab_permutation = torch.randperm(
+            self.vocab_size, device=input_ids.device, generator=self.rng
+        )
+        greenlist_ids = vocab_permutation[:greenlist_size]
+        return greenlist_ids
+
+    @staticmethod
+    def _calc_greenlist_mask(
+        scores: torch.FloatTensor, greenlist_token_ids
+    ) -> torch.BoolTensor:
+        green_tokens_mask = torch.zeros_like(scores)
+        green_tokens_mask[-1, greenlist_token_ids] = 1
+        final_mask = green_tokens_mask.bool()
+        return final_mask
+
+    @staticmethod
+    def _bias_greenlist_logits(
+        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+    ) -> torch.Tensor:
+        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+        return scores
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        assert len(input_ids) == 1
+        greenlist_ids = self._get_greenlist_ids(input_ids[0])
+        green_tokens_mask = self._calc_greenlist_mask(
+            scores=scores, greenlist_token_ids=greenlist_ids
+        )
+
+        scores = self._bias_greenlist_logits(
+            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
+        )
+        return scores

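Aside: in isolation the processor is easy to exercise: gamma sets the fraction of the vocabulary placed on the green list and delta is the logit bonus those tokens receive. A self-contained sketch with toy values (not part of the diff):

    import torch

    proc = WatermarkLogitsProcessor(vocab_size=100, gamma=0.5, delta=2.0)
    input_ids = torch.tensor([[42]])  # one sequence; its last token seeds the rng
    scores = torch.zeros(1, 100)      # flat logits make the bias easy to see
    biased = proc(input_ids, scores)
    print(int((biased[0] > 0).sum()))  # 50: exactly gamma * vocab_size tokens got +delta

Because the rng is reseeded from the previous token on every step, the green list is deterministic given the prefix, which is what lets a detector recompute it without access to the model.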