parent
941cd42e0c
commit
1a2d68250a
|
@ -34,14 +34,16 @@ message NextTokenChooserParameters {
|
|||
uint32 top_k = 2;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float top_p = 3;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float typical_p = 4;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 4;
|
||||
bool do_sample = 5;
|
||||
/// random seed for sampling
|
||||
uint64 seed = 5;
|
||||
uint64 seed = 6;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 6;
|
||||
float repetition_penalty = 7;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 7;
|
||||
bool watermark = 8;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
|
|
|
@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters {
|
|||
)]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
exclusive_minimum = 0.0,
|
||||
maximum = 1.0,
|
||||
nullable = true,
|
||||
default = "null",
|
||||
example = 0.95
|
||||
)]
|
||||
pub typical_p: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(default = "false", example = true)]
|
||||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
|
@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
return_full_text: None,
|
||||
|
|
|
@ -231,6 +231,7 @@ mod tests {
|
|||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 0.0,
|
||||
typical_p: 0.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
|
|
|
@ -68,6 +68,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
return_full_text: None,
|
||||
|
|
|
@ -153,6 +153,7 @@ fn validate(
|
|||
repetition_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
typical_p,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
|
@ -171,22 +172,34 @@ fn validate(
|
|||
return Err(ValidationError::RepetitionPenalty);
|
||||
}
|
||||
|
||||
let top_p = top_p.unwrap_or(1.0);
|
||||
if top_p <= 0.0 || top_p > 1.0 {
|
||||
return Err(ValidationError::TopP);
|
||||
}
|
||||
|
||||
// Different because the proto default value is 0 while it is not a valid value
|
||||
// Different because the proto default value is not a valid value
|
||||
// for the user
|
||||
let top_k: u32 = match top_k {
|
||||
None => Ok(0),
|
||||
Some(top_k) => {
|
||||
if top_k <= 0 {
|
||||
let top_p = top_p
|
||||
.map(|value| {
|
||||
if value <= 0.0 || value >= 1.0 {
|
||||
return Err(ValidationError::TopP);
|
||||
}
|
||||
Ok(value)
|
||||
})
|
||||
.unwrap_or(Ok(1.0))?;
|
||||
|
||||
let typical_p = typical_p
|
||||
.map(|value| {
|
||||
if value <= 0.0 || value >= 1.0 {
|
||||
return Err(ValidationError::TypicalP);
|
||||
}
|
||||
Ok(value)
|
||||
})
|
||||
.unwrap_or(Ok(1.0))?;
|
||||
|
||||
let top_k: u32 = top_k
|
||||
.map(|value| {
|
||||
if value <= 0 {
|
||||
return Err(ValidationError::TopK);
|
||||
}
|
||||
Ok(top_k as u32)
|
||||
}
|
||||
}?;
|
||||
Ok(value as u32)
|
||||
})
|
||||
.unwrap_or(Ok(0))?;
|
||||
|
||||
if max_new_tokens == 0 {
|
||||
return Err(ValidationError::MaxNewTokens);
|
||||
|
@ -231,6 +244,7 @@ fn validate(
|
|||
repetition_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
typical_p,
|
||||
do_sample,
|
||||
seed,
|
||||
watermark,
|
||||
|
@ -275,10 +289,12 @@ pub enum ValidationError {
|
|||
Temperature,
|
||||
#[error("`repetition_penalty` must be strictly positive")]
|
||||
RepetitionPenalty,
|
||||
#[error("`top_p` must be > 0.0 and <= 1.0")]
|
||||
#[error("`top_p` must be > 0.0 and < 1.0")]
|
||||
TopP,
|
||||
#[error("`top_k` must be strictly positive")]
|
||||
TopK,
|
||||
#[error("`typical_p` must be > 0.0 and < 1.0")]
|
||||
TypicalP,
|
||||
#[error("`max_new_tokens` must be strictly positive")]
|
||||
MaxNewTokens,
|
||||
#[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
|
||||
|
|
|
@ -10,6 +10,7 @@ def default_pb_parameters():
|
|||
repetition_penalty=1.0,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
typical_p=1.0,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from transformers import (
|
|||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
@ -41,6 +42,7 @@ class NextTokenChooser:
|
|||
repetition_penalty=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
do_sample=False,
|
||||
seed=0,
|
||||
device="cpu",
|
||||
|
@ -64,6 +66,9 @@ class NextTokenChooser:
|
|||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
sampling = True
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||
sampling = True
|
||||
|
||||
self.warpers = warpers
|
||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
@ -92,6 +97,7 @@ class NextTokenChooser:
|
|||
repetition_penalty=pb.repetition_penalty,
|
||||
top_k=pb.top_k,
|
||||
top_p=pb.top_p,
|
||||
typical_p=pb.typical_p,
|
||||
do_sample=pb.do_sample,
|
||||
seed=pb.seed,
|
||||
device=device,
|
||||
|
|
Loading…
Reference in New Issue