feat(server): support repetition penalty (#47)
This commit is contained in:
parent
2ad895a6cc
commit
313194f6d7
|
@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
|
|||
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
||||
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
|
||||
- 45ms per token generation for BLOOM with 8xA100 80GB
|
||||
- Logits warpers (temperature scaling, topk ...)
|
||||
- Logits warpers (temperature scaling, topk, repetition penalty ...)
|
||||
- Stop sequences
|
||||
- Log probabilities
|
||||
|
||||
|
|
|
@ -38,6 +38,8 @@ message NextTokenChooserParameters {
|
|||
bool do_sample = 4;
|
||||
/// random seed for sampling
|
||||
uint64 seed = 5;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 6;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
|
|
|
@ -13,6 +13,8 @@ use validation::Validation;
|
|||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default = "default_temperature")]
|
||||
pub temperature: f32,
|
||||
#[serde(default = "default_repetition_penalty")]
|
||||
pub repetition_penalty: f32,
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
#[serde(default = "default_top_p")]
|
||||
|
@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
|
|||
fn default_temperature() -> f32 {
|
||||
1.0
|
||||
}
|
||||
fn default_repetition_penalty() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_top_k() -> i32 {
|
||||
0
|
||||
|
@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
|
|||
fn default_parameters() -> GenerateParameters {
|
||||
GenerateParameters {
|
||||
temperature: default_temperature(),
|
||||
repetition_penalty: default_repetition_penalty(),
|
||||
top_k: default_top_k(),
|
||||
top_p: default_top_p(),
|
||||
do_sample: default_do_sample(),
|
||||
|
|
|
@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
temperature: 1.0,
|
||||
repetition_penalty: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
do_sample: false,
|
||||
|
|
|
@ -113,6 +113,9 @@ fn validate(
|
|||
if request.parameters.temperature <= 0.0 {
|
||||
return Err(ValidationError::Temperature);
|
||||
}
|
||||
if request.parameters.repetition_penalty <= 0.0 {
|
||||
return Err(ValidationError::RepetitionPenalty);
|
||||
}
|
||||
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
|
||||
return Err(ValidationError::TopP);
|
||||
}
|
||||
|
@ -146,6 +149,7 @@ fn validate(
|
|||
// Return ValidGenerateRequest
|
||||
let GenerateParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
do_sample,
|
||||
|
@ -156,6 +160,7 @@ fn validate(
|
|||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
top_k: top_k as u32,
|
||||
top_p,
|
||||
do_sample,
|
||||
|
@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
|
|||
pub enum ValidationError {
|
||||
#[error("temperature must be strictly positive")]
|
||||
Temperature,
|
||||
#[error("repetition_penalty must be strictly positive")]
|
||||
RepetitionPenalty,
|
||||
#[error("top_p must be > 0.0 and <= 1.0")]
|
||||
TopP,
|
||||
#[error("top_k must be strictly positive")]
|
||||
|
|
|
@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
|
|||
def default_pb_parameters():
|
||||
return generate_pb2.NextTokenChooserParameters(
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
do_sample=False,
|
||||
|
|
|
@ -336,7 +336,7 @@ class CausalLM(Model):
|
|||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
tokens, logprobs = next_token_chooser(all_input_ids, logits)
|
||||
tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
|
||||
next_token_id = tokens[-1].view(1, 1)
|
||||
|
||||
# Append next token to all tokens
|
||||
|
|
|
@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
|
|||
decoder_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
decoder_input_ids.view(1, -1), logits
|
||||
)
|
||||
|
||||
# Append next token to decoder tokens
|
||||
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
|
||||
|
|
|
@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
|
|||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
|
@ -48,6 +49,7 @@ class NextTokenChooser:
|
|||
def __init__(
|
||||
self,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
do_sample=False,
|
||||
|
@ -68,6 +70,9 @@ class NextTokenChooser:
|
|||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
sampling = True
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
sampling = True
|
||||
|
||||
self.warpers = warpers
|
||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
@ -75,8 +80,10 @@ class NextTokenChooser:
|
|||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
scores = self.warpers(input_ids, scores)
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
|
||||
# Choose tokens
|
||||
next_ids = self.choice(scores)
|
||||
return next_ids, logprobs
|
||||
|
@ -87,6 +94,7 @@ class NextTokenChooser:
|
|||
) -> "NextTokenChooser":
|
||||
return NextTokenChooser(
|
||||
temperature=pb.temperature,
|
||||
repetition_penalty=pb.repetition_penalty,
|
||||
top_k=pb.top_k,
|
||||
top_p=pb.top_p,
|
||||
do_sample=pb.do_sample,
|
||||
|
|
Loading…
Reference in New Issue