From 313194f6d75215bc0a8d1f78ef352dbc705a48a5 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 1 Feb 2023 15:58:42 +0100
Subject: [PATCH] feat(server): support repetition penalty (#47)

---
 README.md                                   | 2 +-
 proto/generate.proto                        | 2 ++
 router/src/lib.rs                           | 6 ++++++
 router/src/server.rs                        | 1 +
 router/src/validation.rs                    | 7 +++++++
 server/tests/conftest.py                    | 1 +
 server/text_generation/models/causal_lm.py  | 2 +-
 server/text_generation/models/seq2seq_lm.py | 4 +++-
 server/text_generation/utils.py             | 8 ++++++++
 9 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index d54a092a..d092781a 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
-- Logits warpers (temperature scaling, topk ...)
+- Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities
diff --git a/proto/generate.proto b/proto/generate.proto
index 8f431c5c..098df9c5 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
     bool do_sample = 4;
     /// random seed for sampling
     uint64 seed = 5;
+    /// repetition penalty
+    float repetition_penalty = 6;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index beab7138..5b96485f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -13,6 +13,8 @@ use validation::Validation;
 pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
+    #[serde(default = "default_repetition_penalty")]
+    pub repetition_penalty: f32,
     #[serde(default = "default_top_k")]
     pub top_k: i32,
     #[serde(default = "default_top_p")]
     pub top_p: f32,
@@ -32,6 +34,9 @@ fn default_temperature() -> f32 {
     1.0
 }
+fn default_repetition_penalty() -> f32 {
+    1.0
+}
 
 fn default_top_k() -> i32 {
     0
@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         temperature: default_temperature(),
+        repetition_penalty: default_repetition_penalty(),
         top_k: default_top_k(),
         top_p: default_top_p(),
         do_sample: default_do_sample(),
diff --git a/router/src/server.rs b/router/src/server.rs
index ef3782d6..c31ca6ce 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
             parameters: GenerateParameters {
                 temperature: 1.0,
+                repetition_penalty: 1.0,
                 top_k: 0,
                 top_p: 1.0,
                 do_sample: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -…,6 +…,9 @@ fn validate(
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
+    if request.parameters.repetition_penalty <= 0.0 {
+        return Err(ValidationError::RepetitionPenalty);
+    }
     if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
         return Err(ValidationError::TopP);
     }
@@ -146,6 +149,7 @@ fn validate(
     // Return ValidGenerateRequest
     let GenerateParameters {
         temperature,
+        repetition_penalty,
         top_k,
         top_p,
         do_sample,
@@ -156,6 +160,7 @@ fn validate(
     let parameters = NextTokenChooserParameters {
         temperature,
+        repetition_penalty,
         top_k: top_k as u32,
         top_p,
         do_sample,
@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
+    #[error("repetition_penalty must be strictly positive")]
+    RepetitionPenalty,
     #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index e0ed76b4..9fae8ee1 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
         do_sample=False,
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 4dc834b8..994c57d5 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -336,7 +336,7 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
             next_token_id = tokens[-1].view(1, 1)
 
             # Append next token to all tokens
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 29492dd7..1ae266d8 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )
 
             # Append next token to decoder tokens
             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 3b07ef3f..91e6b7b7 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
 from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
     TopKLogitsWarper,
@@ -48,6 +49,7 @@ class NextTokenChooser:
     def __init__(
         self,
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         do_sample=False,
@@ -68,6 +70,9 @@
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+            sampling = True
 
         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -75,8 +80,10 @@
     def __call__(self, input_ids, scores):
         # Warp logits
         scores = self.warpers(input_ids, scores)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
+
         # Choose tokens
         next_ids = self.choice(scores)
         return next_ids, logprobs
@@ -87,6 +94,7 @@
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
+            repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
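
Note (not part of the patch): the new repetition_penalty value travels from the HTTP request through the proto message into NextTokenChooser, where any value other than 1.0 adds transformers' RepetitionPenaltyLogitsProcessor to the warper list. That processor gathers the scores of every token id already present in input_ids along the vocabulary dimension, which needs a batch dimension on both tensors; this is why the causal_lm.py and seq2seq_lm.py call sites now pass the 1-D id tensors as .view(1, -1). The sketch below is illustrative only, with made-up toy values (vocab_size, the token ids, penalty=1.3), not code from this repository:

# Illustrative sketch only: what the penalty does to the logits and why the
# call sites reshape the 1-D id tensor to (1, -1) before calling the chooser.
import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

vocab_size = 8
scores = torch.zeros(1, vocab_size)   # one sequence, toy vocabulary
scores[0, 3] = 2.0                    # token 3 currently has the highest logit
scores[0, 5] = 1.5

input_ids = torch.tensor([5, 3, 3])   # tokens generated so far (1-D, as in the models)
processor = RepetitionPenaltyLogitsProcessor(penalty=1.3)

# Mirrors the patched call sites: reshape to (1, -1) so the processor's
# gather/scatter over dim 1 sees a batch dimension.
penalized = processor(input_ids.view(1, -1), scores.clone())

print(penalized[0, 3].item())  # ~1.54: 2.0 / 1.3, an already-seen token is demoted
print(penalized[0, 5].item())  # ~1.15: 1.5 / 1.3
print(penalized[0, 0].item())  # 0.0: unseen tokens are left untouched

Positive logits of previously generated tokens are divided by the penalty and negative ones multiplied, so a penalty above 1.0 discourages repetition, and the default of 1.0 is a no-op, which is why the patch only registers the processor when the value differs from 1.0.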