feat(server): support repetition penalty (#47)

This commit is contained in:
OlivierDehaene 2023-02-01 15:58:42 +01:00 committed by GitHub
parent 2ad895a6cc
commit 313194f6d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 30 additions and 3 deletions

View File

@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
-- Logits warpers (temperature scaling, topk ...)
+- Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities

View File

@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
   bool do_sample = 4;
   /// random seed for sampling
   uint64 seed = 5;
+  /// repetition penalty
+  float repetition_penalty = 6;
 }
 
 message StoppingCriteriaParameters {

View File

@@ -13,6 +13,8 @@ use validation::Validation;
 pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
+    #[serde(default = "default_repetition_penalty")]
+    pub repetition_penalty: f32,
     #[serde(default = "default_top_k")]
     pub top_k: i32,
     #[serde(default = "default_top_p")]
@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
 fn default_temperature() -> f32 {
     1.0
 }
+fn default_repetition_penalty() -> f32 {
+    1.0
+}
 
 fn default_top_k() -> i32 {
     0
@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         temperature: default_temperature(),
+        repetition_penalty: default_repetition_penalty(),
         top_k: default_top_k(),
         top_p: default_top_p(),
         do_sample: default_do_sample(),

View File

@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         inputs: "liveness".to_string(),
         parameters: GenerateParameters {
             temperature: 1.0,
+            repetition_penalty: 1.0,
             top_k: 0,
             top_p: 1.0,
             do_sample: false,

View File

@@ -113,6 +113,9 @@ fn validate(
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
+    if request.parameters.repetition_penalty <= 0.0 {
+        return Err(ValidationError::RepetitionPenalty);
+    }
     if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
         return Err(ValidationError::TopP);
     }
@@ -146,6 +149,7 @@ fn validate(
     // Return ValidGenerateRequest
     let GenerateParameters {
         temperature,
+        repetition_penalty,
         top_k,
         top_p,
         do_sample,
@@ -156,6 +160,7 @@ fn validate(
 
     let parameters = NextTokenChooserParameters {
         temperature,
+        repetition_penalty,
         top_k: top_k as u32,
         top_p,
         do_sample,
@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
+    #[error("repetition_penalty must be strictly positive")]
+    RepetitionPenalty,
     #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]

View File

@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
         do_sample=False,

View File

@@ -336,7 +336,7 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
             next_token_id = tokens[-1].view(1, 1)
 
             # Append next token to all tokens

View File

@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )
 
             # Append next token to decoder tokens
             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])

View File

@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
 from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
     TopKLogitsWarper,
@@ -48,6 +49,7 @@ class NextTokenChooser:
     def __init__(
         self,
         temperature=1.0,
+        repetition_penalty=1.0,
        top_k=None,
         top_p=None,
         do_sample=False,
@@ -68,6 +70,9 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+            sampling = True
 
         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -75,8 +80,10 @@ class NextTokenChooser:
     def __call__(self, input_ids, scores):
         # Warp logits
         scores = self.warpers(input_ids, scores)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
+
         # Choose tokens
         next_ids = self.choice(scores)
         return next_ids, logprobs
@@ -87,6 +94,7 @@ class NextTokenChooser:
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
+            repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,