From 1a2d68250aa7dfbe1fa52b22eec07edfb7b895fb Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 9 Mar 2023 11:33:57 +0100 Subject: [PATCH] feat: support typical sampling (#114) closes #112 --- proto/generate.proto | 10 +++-- router/src/lib.rs | 10 +++++ router/src/queue.rs | 1 + router/src/server.rs | 1 + router/src/validation.rs | 44 +++++++++++++------ server/tests/conftest.py | 1 + server/text_generation_server/utils/tokens.py | 6 +++ 7 files changed, 55 insertions(+), 18 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index dccd7e59..a47e2ec1 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -34,14 +34,16 @@ message NextTokenChooserParameters { uint32 top_k = 2; /// restricting to top tokens summing to prob_cut_off <= prob_cut_off float top_p = 3; + /// restricting to tokens within the typical_p probability mass (locally typical sampling) + float typical_p = 4; /// apply sampling on the logits - bool do_sample = 4; + bool do_sample = 5; /// random seed for sampling - uint64 seed = 5; + uint64 seed = 6; /// repetition penalty - float repetition_penalty = 6; + float repetition_penalty = 7; /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 7; + bool watermark = 8; } message StoppingCriteriaParameters { diff --git a/router/src/lib.rs b/router/src/lib.rs index 1819541d..d375eafb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters { )] pub top_p: Option<f32>, #[serde(default)] #[schema( exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null", example = 0.95 )] pub typical_p: Option<f32>, #[serde(default)] #[schema(default = "false", example = true)] pub do_sample: bool, #[serde(default = "default_max_new_tokens")] pub max_new_tokens: u32, @@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters { repetition_penalty: None, top_k: None, top_p: None, + typical_p: None, do_sample: false, max_new_tokens: default_max_new_tokens(), return_full_text: 
None, diff --git a/router/src/queue.rs b/router/src/queue.rs index 0ebfed9b..db3c509e 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -231,6 +231,7 @@ mod tests { temperature: 0.0, top_k: 0, top_p: 0.0, + typical_p: 0.0, do_sample: false, seed: 0, repetition_penalty: 0.0, diff --git a/router/src/server.rs b/router/src/server.rs index b8b6b440..2ce5699d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -68,6 +68,7 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json 1.0 { - return Err(ValidationError::TopP); - } - - // Different because the proto default value is 0 while it is not a valid value + // Different because the proto default value is not a valid value // for the user - let top_k: u32 = match top_k { - None => Ok(0), - Some(top_k) => { - if top_k <= 0 { + let top_p = top_p + .map(|value| { + if value <= 0.0 || value >= 1.0 { + return Err(ValidationError::TopP); + } + Ok(value) + }) + .unwrap_or(Ok(1.0))?; + + let typical_p = typical_p + .map(|value| { + if value <= 0.0 || value >= 1.0 { + return Err(ValidationError::TypicalP); + } + Ok(value) + }) + .unwrap_or(Ok(1.0))?; + + let top_k: u32 = top_k + .map(|value| { + if value <= 0 { return Err(ValidationError::TopK); } - Ok(top_k as u32) - } - }?; + Ok(value as u32) + }) + .unwrap_or(Ok(0))?; if max_new_tokens == 0 { return Err(ValidationError::MaxNewTokens); @@ -231,6 +244,7 @@ fn validate( repetition_penalty, top_k, top_p, + typical_p, do_sample, seed, watermark, @@ -275,10 +289,12 @@ pub enum ValidationError { Temperature, #[error("`repetition_penalty` must be strictly positive")] RepetitionPenalty, - #[error("`top_p` must be > 0.0 and <= 1.0")] + #[error("`top_p` must be > 0.0 and < 1.0")] TopP, #[error("`top_k` must be strictly positive")] TopK, + #[error("`typical_p` must be > 0.0 and < 1.0")] + TypicalP, #[error("`max_new_tokens` must be strictly positive")] MaxNewTokens, #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. 
Given: {1} `inputs` tokens and {2} `max_new_tokens`")] diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 04c909ef..16d2c408 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -10,6 +10,7 @@ def default_pb_parameters(): repetition_penalty=1.0, top_k=0, top_p=1.0, + typical_p=1.0, do_sample=False, ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index aa76b6eb..15563761 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -6,6 +6,7 @@ from transformers import ( TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, + TypicalLogitsWarper, RepetitionPenaltyLogitsProcessor, PreTrainedTokenizerBase, ) @@ -41,6 +42,7 @@ class NextTokenChooser: repetition_penalty=1.0, top_k=None, top_p=None, + typical_p=None, do_sample=False, seed=0, device="cpu", @@ -64,6 +66,9 @@ class NextTokenChooser: if top_p is not None and top_p < 1.0: warpers.append(TopPLogitsWarper(top_p=top_p)) sampling = True + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper(mass=typical_p)) + sampling = True self.warpers = warpers self.choice = Sampling(seed, device) if sampling else Greedy() @@ -92,6 +97,7 @@ class NextTokenChooser: repetition_penalty=pb.repetition_penalty, top_k=pb.top_k, top_p=pb.top_p, + typical_p=pb.typical_p, do_sample=pb.do_sample, seed=pb.seed, device=device,