feat(server): support repetition penalty (#47)

OlivierDehaene 2023-02-01 15:58:42 +01:00 committed by GitHub
parent 2ad895a6cc
commit 313194f6d7
9 changed files with 30 additions and 3 deletions

View File

@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
- Logits warpers (temperature scaling, topk ...)
- Logits warpers (temperature scaling, topk, repetition penalty ...)
- Stop sequences
- Log probabilities
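
For a sense of how the new knob is exposed to clients, a request might look like the sketch below (Python; the /generate route, port, and exact payload shape are assumptions inferred from the GenerateParameters struct later in this diff, not confirmed by it):

import requests

# Hypothetical endpoint; field names follow the router's GenerateParameters.
response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "The capital of France is",
        "parameters": {
            "temperature": 0.9,
            "repetition_penalty": 1.2,  # 1.0 = disabled (the serde default below)
            "do_sample": True,
            "max_new_tokens": 20,
        },
    },
)
print(response.json())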

View File

@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
    bool do_sample = 4;
    /// random seed for sampling
    uint64 seed = 5;
    /// repetition penalty
    float repetition_penalty = 6;
}

message StoppingCriteriaParameters {

View File

@@ -13,6 +13,8 @@ use validation::Validation;
pub(crate) struct GenerateParameters {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_repetition_penalty")]
    pub repetition_penalty: f32,
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    #[serde(default = "default_top_p")]
@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
fn default_temperature() -> f32 {
    1.0
}

fn default_repetition_penalty() -> f32 {
    1.0
}

fn default_top_k() -> i32 {
    0
@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        repetition_penalty: default_repetition_penalty(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),

View File

@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
        inputs: "liveness".to_string(),
        parameters: GenerateParameters {
            temperature: 1.0,
            repetition_penalty: 1.0,
            top_k: 0,
            top_p: 1.0,
            do_sample: false,

View File

@@ -113,6 +113,9 @@ fn validate(
    if request.parameters.temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }
    if request.parameters.repetition_penalty <= 0.0 {
        return Err(ValidationError::RepetitionPenalty);
    }
    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
        return Err(ValidationError::TopP);
    }
@@ -146,6 +149,7 @@ fn validate(
    // Return ValidGenerateRequest
    let GenerateParameters {
        temperature,
        repetition_penalty,
        top_k,
        top_p,
        do_sample,
@@ -156,6 +160,7 @@ fn validate(
    let parameters = NextTokenChooserParameters {
        temperature,
        repetition_penalty,
        top_k: top_k as u32,
        top_p,
        do_sample,
@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
pub enum ValidationError {
    #[error("temperature must be strictly positive")]
    Temperature,
    #[error("repetition_penalty must be strictly positive")]
    RepetitionPenalty,
    #[error("top_p must be > 0.0 and <= 1.0")]
    TopP,
    #[error("top_k must be strictly positive")]

View File

@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
def default_pb_parameters():
    return generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=False,

View File

@@ -336,7 +336,7 @@ class CausalLM(Model):
                all_input_ids,
            ) in enumerate(iterator):
                # Select next token
                tokens, logprobs = next_token_chooser(all_input_ids, logits)
                tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
                next_token_id = tokens[-1].view(1, 1)

                # Append next token to all tokens
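
The only change here is reshaping all_input_ids from a flat 1-D tensor to (1, sequence_length) before calling the chooser: transformers' RepetitionPenaltyLogitsProcessor gathers scores along dimension 1, so it needs 2-D (batch, seq_len) ids that line up with the (batch, vocab) logits; the Seq2SeqLM path below gets the same fix. A standalone toy illustration (not the server's code):

import torch
from transformers import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
all_input_ids = torch.tensor([5, 5, 7])  # flat ids, as in the generation loop
logits = torch.randn(1, 10)              # (batch_size=1, vocab_size=10)

# view(1, -1) produces the (batch, seq_len) layout the processor's
# torch.gather call expects; passing the 1-D tensor would raise an error.
penalized = processor(all_input_ids.view(1, -1), logits.clone())
print(logits[0, 5], penalized[0, 5])     # token 5's score is always pushed down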

View File

@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
                decoder_input_ids,
            ) in enumerate(iterator):
                # Select next token
                next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
                next_token_id, logprobs = next_token_chooser(
                    decoder_input_ids.view(1, -1), logits
                )

                # Append next token to decoder tokens
                decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])

View File

@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
@@ -48,6 +49,7 @@ class NextTokenChooser:
    def __init__(
        self,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        do_sample=False,
@@ -68,6 +70,9 @@
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -75,8 +80,10 @@
    def __call__(self, input_ids, scores):
        # Warp logits
        scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_ids = self.choice(scores)
        return next_ids, logprobs
@@ -87,6 +94,7 @@ class NextTokenChooser:
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
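
As a final illustration of what the new warper actually does to the distribution: for every token id already present in input_ids, a positive logit is divided by the penalty and a negative one is multiplied by it, so any penalty above 1.0 makes previously generated tokens less likely. A small standalone sketch with toy values (not part of this commit's tests):

import torch
from transformers import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)

input_ids = torch.tensor([[0, 2]])               # tokens 0 and 2 were already generated
scores = torch.tensor([[4.0, 4.0, -1.0, -1.0]])

out = processor(input_ids, scores)
# token 0: 4.0 / 2.0 -> 2.0   (positive logits are divided by the penalty)
# token 2: -1.0 * 2.0 -> -2.0 (negative logits are multiplied by it)
# tokens 1 and 3 keep their original scores
print(out)                                       # tensor([[ 2.,  4., -2., -1.]])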