feat(server): add logits watermark (#90)
This commit is contained in:
parent f874c47831
commit 9b8ea6a6c7
@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }
 
 fn main() -> ExitCode {
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;
 
     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }
 
+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
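The two new flags are not passed on the shard command line; `shard_manager` forwards them to each Python shard as environment variables. A minimal sketch of the shard-side pickup (variable names come from this diff; the explicit `float` conversion is my addition, since `os.getenv` returns a string when the variable is set):

```python
import os

# Only set when --watermark-gamma / --watermark-delta were given to the
# launcher; fall back to the defaults used by watermark.py below.
gamma = float(os.getenv("WATERMARK_GAMMA", 0.5))
delta = float(os.getenv("WATERMARK_DELTA", 2.0))
```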
@@ -40,6 +40,8 @@ message NextTokenChooserParameters {
     uint64 seed = 5;
     /// repetition penalty
     float repetition_penalty = 6;
+    /// token watermarking using "A Watermark for Large Language Models"
+    bool watermark = 7;
 }
 
 message StoppingCriteriaParameters {
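For orientation, a sketch of building these parameters from the Python side using the generated stubs the server package already imports (the field values other than `watermark` are arbitrary examples):

```python
from text_generation.pb import generate_pb2

# NextTokenChooserParameters with the new field 7 set; omitted fields
# keep their protobuf zero values.
params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    do_sample=False,
    seed=42,
    repetition_penalty=1.0,
    watermark=True,  # new in this commit
)
```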
@@ -53,6 +53,9 @@ pub(crate) struct GenerateParameters {
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "false", example = true)]
+    pub watermark: bool,
+    #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
@@ -72,7 +75,8 @@ fn default_parameters() -> GenerateParameters {
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
-        stop: vec![],
+        stop: Vec::new(),
+        watermark: false,
         details: false,
         seed: None,
     }
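Because the field is `#[serde(default)]` and defaults to `false`, existing clients are unaffected; opting in is one extra key in the request body. A hypothetical call against a local deployment (host, port, and prompt are assumptions, not part of this diff):

```python
import requests

# Hypothetical local server; the /generate route takes GenerateParameters
# under the "parameters" key.
response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20, "watermark": True},
    },
)
print(response.json())
```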
@@ -234,6 +234,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                watermark: false
             },
             stopping_parameters: StoppingCriteriaParameters {
                 max_new_tokens: 0,
@@ -72,6 +72,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
             max_new_tokens: 1,
             return_full_text: None,
             stop: Vec::new(),
+            watermark: false,
             details: false,
             seed: None,
         },
@@ -157,6 +157,7 @@ fn validate(
         max_new_tokens,
         stop: stop_sequences,
         seed,
+        watermark,
         ..
     } = request.parameters;
 
@@ -232,6 +233,7 @@ fn validate(
         top_p,
         do_sample,
         seed,
+        watermark,
     };
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,
@@ -67,7 +67,9 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -100,7 +100,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
@@ -77,7 +77,9 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -13,6 +13,7 @@ from typing import List, Tuple, Optional
 
 from text_generation.pb import generate_pb2
 from text_generation.pb.generate_pb2 import FinishReason
+from text_generation.utils.watermark import WatermarkLogitsProcessor
 
 
 class Sampling:
@@ -35,6 +36,8 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
+        vocab_size,
+        watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
@@ -47,6 +50,11 @@ class NextTokenChooser:
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         sampling = do_sample
+
+        if watermark:
+            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
             temperature = float(temperature)
             warpers.append(TemperatureLogitsWarper(temperature))
@@ -57,8 +65,6 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
 
         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
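The move above is deliberate: entries in `warpers` are applied to the logits in list order, so the watermark bias and repetition penalty now run before the temperature/top-k/top-p warpers instead of after them. A simplified sketch of that dispatch (illustrative only; the real class also routes the result through `self.choice`):

```python
import torch

def apply_warpers(warpers, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    # Each processor transforms the output of its predecessor, so list
    # order decides which transform sees the raw logits first.
    for warper in warpers:
        scores = warper(input_ids, scores)
    return scores
```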
@@ -77,9 +83,14 @@ class NextTokenChooser:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        vocab_size: int,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
+            vocab_size=vocab_size,
+            watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2023 Authors of "A Watermark for Large Language Models"
+# available at https://arxiv.org/abs/2301.10226
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+from transformers import LogitsProcessor
+
+GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
+DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
+
+
+class WatermarkLogitsProcessor(LogitsProcessor):
+    def __init__(
+        self,
+        vocab_size: int,
+        gamma: float = GAMMA,
+        delta: float = DELTA,
+        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+        device: str = "cpu",
+    ):
+        # watermarking parameters
+        self.vocab_size = vocab_size
+        self.gamma = gamma
+        self.delta = delta
+        self.rng = torch.Generator(device=device)
+        self.hash_key = hash_key
+
+    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
+        assert (
+            input_ids.shape[-1] >= 1
+        ), "requires at least a 1 token prefix sequence to seed rng"
+        prev_token = input_ids[-1].item()
+        self.rng.manual_seed(self.hash_key * prev_token)
+
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        # seed the rng using the previous tokens/prefix
+        self._seed_rng(input_ids)
+
+        greenlist_size = int(self.vocab_size * self.gamma)
+        vocab_permutation = torch.randperm(
+            self.vocab_size, device=input_ids.device, generator=self.rng
+        )
+        greenlist_ids = vocab_permutation[:greenlist_size]
+        return greenlist_ids
+
+    @staticmethod
+    def _calc_greenlist_mask(
+        scores: torch.FloatTensor, greenlist_token_ids
+    ) -> torch.BoolTensor:
+        green_tokens_mask = torch.zeros_like(scores)
+        green_tokens_mask[-1, greenlist_token_ids] = 1
+        final_mask = green_tokens_mask.bool()
+        return final_mask
+
+    @staticmethod
+    def _bias_greenlist_logits(
+        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+    ) -> torch.Tensor:
+        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+        return scores
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        assert len(input_ids) == 1
+        greenlist_ids = self._get_greenlist_ids(input_ids[0])
+        green_tokens_mask = self._calc_greenlist_mask(
+            scores=scores, greenlist_token_ids=greenlist_ids
+        )
+
+        scores = self._bias_greenlist_logits(
+            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
+        )
+        return scores
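This commit only adds the embedding side of the scheme; detection is not part of the server. For orientation, a rough sketch of the paper's detector under the same seeding scheme (`detect_watermark` is illustrative, not a TGI API; it assumes the class above is importable, and the z-score is the one-proportion test from arXiv:2301.10226):

```python
import math

import torch

def detect_watermark(sequence: torch.LongTensor, proc: WatermarkLogitsProcessor) -> float:
    # Re-derive each position's greenlist from its predecessor, exactly as
    # _get_greenlist_ids does at generation time, and count green tokens.
    green, total = 0, sequence.shape[-1] - 1
    for i in range(1, sequence.shape[-1]):
        greenlist = proc._get_greenlist_ids(sequence[:i])
        if bool((greenlist == sequence[i]).any()):
            green += 1
    # Under the null hypothesis each token lands in the greenlist with
    # probability gamma; a large z-score suggests the watermark is present.
    return (green - proc.gamma * total) / math.sqrt(total * proc.gamma * (1 - proc.gamma))
```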