feat(server): support repetition penalty (#47)
parent 2ad895a6cc
commit 313194f6d7
@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
-- Logits warpers (temperature scaling, topk ...)
+- Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities
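To see the new knob end to end, a request against the router might look like the sketch below. This is not part of the commit: the `/generate` route, host, port, and payload shape are inferred from the `GenerateParameters` struct in the hunks that follow, so treat them as placeholders for your deployment.

    import requests

    # Hypothetical endpoint; substitute the address your router listens on.
    response = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "The quick brown fox",
            "parameters": {
                "temperature": 0.8,
                "repetition_penalty": 1.2,  # > 1.0 discourages repeated tokens
                "top_k": 50,
                "top_p": 0.95,
                "do_sample": True,
            },
        },
    )
    print(response.json())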
@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
     bool do_sample = 4;
     /// random seed for sampling
     uint64 seed = 5;
+    /// repetition penalty
+    float repetition_penalty = 6;
 }

 message StoppingCriteriaParameters {
@@ -13,6 +13,8 @@ use validation::Validation;
 pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
+    #[serde(default = "default_repetition_penalty")]
+    pub repetition_penalty: f32,
     #[serde(default = "default_top_k")]
     pub top_k: i32,
     #[serde(default = "default_top_p")]
@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
 fn default_temperature() -> f32 {
     1.0
 }
+fn default_repetition_penalty() -> f32 {
+    1.0
+}

 fn default_top_k() -> i32 {
     0
@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         temperature: default_temperature(),
+        repetition_penalty: default_repetition_penalty(),
         top_k: default_top_k(),
         top_p: default_top_p(),
         do_sample: default_do_sample(),
@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         inputs: "liveness".to_string(),
         parameters: GenerateParameters {
             temperature: 1.0,
+            repetition_penalty: 1.0,
             top_k: 0,
             top_p: 1.0,
             do_sample: false,
@@ -113,6 +113,9 @@ fn validate(
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
+    if request.parameters.repetition_penalty <= 0.0 {
+        return Err(ValidationError::RepetitionPenalty);
+    }
     if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
         return Err(ValidationError::TopP);
     }
@@ -146,6 +149,7 @@ fn validate(
     // Return ValidGenerateRequest
     let GenerateParameters {
         temperature,
+        repetition_penalty,
         top_k,
         top_p,
         do_sample,
@@ -156,6 +160,7 @@ fn validate(

     let parameters = NextTokenChooserParameters {
         temperature,
+        repetition_penalty,
         top_k: top_k as u32,
         top_p,
         do_sample,
@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
+    #[error("repetition_penalty must be strictly positive")]
+    RepetitionPenalty,
     #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
         do_sample=False,
@@ -336,7 +336,7 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
             next_token_id = tokens[-1].view(1, 1)

             # Append next token to all tokens
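The `.view(1, -1)` reshape above matters because `RepetitionPenaltyLogitsProcessor` gathers scores at the positions given by `input_ids`, which requires a 2-D `(batch, seq_len)` tensor aligned with the `(batch, vocab)` logits, while the decode loop tracks a flat 1-D id tensor per sequence. A small shape check, with toy sizes assumed for illustration:

    import torch
    from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

    processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
    all_input_ids = torch.tensor([5, 7, 5])  # flat ids, as kept by the decode loop
    logits = torch.randn(1, 32)              # (batch=1, vocab=32)

    # 1-D ids would make the internal torch.gather fail; 2-D ids line up with logits.
    penalized = processor(all_input_ids.view(1, -1), logits)
    print(penalized.shape)  # torch.Size([1, 32])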
@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )

             # Append next token to decoder tokens
             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
 from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
     TopKLogitsWarper,
@@ -48,6 +49,7 @@ class NextTokenChooser:
     def __init__(
         self,
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         do_sample=False,
@@ -68,6 +70,9 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+            sampling = True

         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
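Note the `!= 1.0` guard: a penalty of exactly 1.0 is the neutral default and skips the processor entirely. For reference, `RepetitionPenaltyLogitsProcessor` applies the penalty from the CTRL paper: the logit of every token already present in the sequence is divided by the penalty when positive and multiplied by it when negative, so any penalty above 1.0 pushes repeated tokens down regardless of sign. A minimal re-derivation of that rule (illustrative, not the library source):

    import torch

    def apply_repetition_penalty(input_ids, scores, penalty):
        # Pick out the logits of tokens that already occurred...
        seen = torch.gather(scores, 1, input_ids)
        # ...shrink them when positive, push them further down when negative...
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        # ...and write them back in place.
        return scores.scatter(1, input_ids, seen)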
@@ -75,8 +80,10 @@ class NextTokenChooser:
     def __call__(self, input_ids, scores):
         # Warp logits
         scores = self.warpers(input_ids, scores)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
+
         # Choose tokens
         next_ids = self.choice(scores)
         return next_ids, logprobs
@@ -87,6 +94,7 @@ class NextTokenChooser:
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
+            repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
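Tying it together, the chooser can be exercised directly once built with a penalty. This is a hedged sketch: the module path and the defaults for constructor arguments not shown in this diff (e.g. `seed`, `device`) are assumptions.

    import torch
    from text_generation.utils import NextTokenChooser  # module path assumed

    # Assumes remaining ctor args (seed, device) have usable defaults.
    chooser = NextTokenChooser(
        temperature=0.8,
        repetition_penalty=1.2,
        top_k=50,
        do_sample=True,
    )

    input_ids = torch.tensor([[3, 9, 3]])  # (batch=1, seq_len=3)
    logits = torch.randn(1, 32)            # (batch=1, vocab=32)

    next_ids, logprobs = chooser(input_ids, logits)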