feat: support typical sampling (#114)

closes #112
OlivierDehaene authored on 2023-03-09 11:33:57 +01:00, committed by GitHub
parent 941cd42e0c
commit 1a2d68250a
7 changed files with 55 additions and 18 deletions


@@ -34,14 +34,16 @@ message NextTokenChooserParameters {
     uint32 top_k = 2;
     /// restricting to top tokens summing to at most prob_cut_off
     float top_p = 3;
+    /// restricting to locally typical tokens summing to at most prob_cut_off
+    float typical_p = 4;
     /// apply sampling on the logits
-    bool do_sample = 4;
+    bool do_sample = 5;
     /// random seed for sampling
-    uint64 seed = 5;
+    uint64 seed = 6;
     /// repetition penalty
-    float repetition_penalty = 6;
+    float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
-    bool watermark = 7;
+    bool watermark = 8;
 }
 
 message StoppingCriteriaParameters {
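
Since the field numbering changed, here is a quick sketch of how the updated message might be built from Python. The `generate_pb2` module name is an assumption about the generated stubs, not part of this commit:

```python
# Hypothetical use of the regenerated protobuf stubs; `generate_pb2` is an
# assumed module name for the code protoc emits from this proto file.
import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    top_k=50,
    top_p=0.9,
    typical_p=0.95,          # new field, tag 4
    do_sample=True,          # renumbered: tag 4 -> 5
    seed=42,                 # renumbered: tag 5 -> 6
    repetition_penalty=1.1,  # renumbered: tag 6 -> 7
    watermark=False,         # renumbered: tag 7 -> 8
)
```

Note that renumbering existing tags is a wire-breaking change in protobuf; that is acceptable here only as long as the router and the model shards are deployed from the same build.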


@@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters {
     )]
     pub top_p: Option<f32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub typical_p: Option<f32>,
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
@@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
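
With `typical_p` exposed on `GenerateParameters`, it can be set per request over the REST API. A sketch, assuming a text-generation-inference server listening on localhost:8080 (host, port, and prompt are illustrative):

```python
# Sketch: exercising the new `typical_p` parameter over the HTTP API.
import requests

response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is typical sampling?",
        "parameters": {
            "typical_p": 0.95,  # nullable; exclusive minimum 0.0 per the schema
            "do_sample": True,
            "max_new_tokens": 64,
        },
    },
)
print(response.json())
```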


@@ -231,6 +231,7 @@ mod tests {
             temperature: 0.0,
             top_k: 0,
             top_p: 0.0,
+            typical_p: 0.0,
             do_sample: false,
             seed: 0,
             repetition_penalty: 0.0,


@@ -68,6 +68,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: 1,
         return_full_text: None,


@@ -153,6 +153,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
@@ -171,22 +172,34 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }
-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
-    // Different because the proto default value is 0 while it is not a valid value
+    // Different because the proto default value is not a valid value
     // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
                 return Err(ValidationError::TopK);
             }
-            Ok(top_k as u32)
-        }
-    }?;
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;
     if max_new_tokens == 0 {
         return Err(ValidationError::MaxNewTokens);
@@ -231,6 +244,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         seed,
         watermark,
@@ -275,10 +289,12 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("`top_p` must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]


@@ -10,6 +10,7 @@ def default_pb_parameters():
         repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
+        typical_p=1.0,
         do_sample=False,
     )


@@ -6,6 +6,7 @@ from transformers import (
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
 )
@@ -41,6 +42,7 @@ class NextTokenChooser:
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
+        typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
@@ -64,6 +66,9 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p))
+            sampling = True
         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -92,6 +97,7 @@ class NextTokenChooser:
             repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
+            typical_p=pb.typical_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
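
`TypicalLogitsWarper` implements locally typical sampling (Meister et al., "Typical Decoding for Natural Language Generation"): it keeps the tokens whose surprisal is closest to the entropy of the predicted distribution, until their cumulative probability reaches `mass`. A rough standalone sketch of that filtering step, for intuition only (the transformers implementation differs in details such as `min_tokens_to_keep`):

```python
import torch

def typical_filter(logits: torch.Tensor, mass: float = 0.95) -> torch.Tensor:
    """Mask out logits outside the locally typical set."""
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # Entropy of the next-token distribution.
    entropy = -(probs * log_probs).sum(dim=-1, keepdim=True)
    # Distance between each token's surprisal (-log p) and the entropy.
    deviation = (entropy + log_probs).abs()
    # Rank tokens from most to least "typical".
    _, sorted_idx = deviation.sort(dim=-1)
    sorted_probs = probs.gather(-1, sorted_idx)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop tokens once `mass` probability is covered, always keeping the
    # first (most typical) token.
    to_remove = cumulative >= mass
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = False
    mask = to_remove.scatter(-1, sorted_idx, to_remove)
    return logits.masked_fill(mask, float("-inf"))
```

With `mass` at 1.0 the filter keeps essentially everything, which is why the warper is only added when `typical_p < 1.0`; as with `top_p`, supplying the parameter also switches the chooser from greedy decoding to sampling.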