diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index c33d64e6..4afce951 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -31,6 +31,7 @@ pub async fn run( typical_p: Option, repetition_penalty: Option, frequency_penalty: Option, + no_repeat_ngram_size: Option, watermark: bool, do_sample: bool, client: ShardedClient, @@ -44,6 +45,7 @@ pub async fn run( seed: 0, repetition_penalty: repetition_penalty.unwrap_or(1.0), frequency_penalty: frequency_penalty.unwrap_or(0.0), + no_repeat_ngram_size: no_repeat_ngram_size.unwrap_or(0), watermark, grammar: String::new(), grammar_type: GrammarType::None as i32, @@ -145,6 +147,7 @@ pub async fn run( typical_p, repetition_penalty, frequency_penalty, + no_repeat_ngram_size, watermark, do_sample, ); diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 2ee3d7c5..3eece4eb 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -89,6 +89,11 @@ struct Args { #[clap(long, env)] frequency_penalty: Option, + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + no_repeat_ngram_size: Option, + /// Generation parameter in case you want to specifically test/debug particular /// decoding strategies, for full doc refer to the `text-generation-server` #[clap(long, env)] @@ -125,6 +130,7 @@ fn main() -> Result<(), Box> { typical_p, repetition_penalty, frequency_penalty, + no_repeat_ngram_size, watermark, do_sample, master_shard_uds_path, @@ -196,6 +202,7 @@ fn main() -> Result<(), Box> { typical_p, repetition_penalty, frequency_penalty, + no_repeat_ngram_size, watermark, do_sample, sharded_client, diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs index 9a2e6ac7..2fd5df12 100644 --- a/router/client/src/v2/client.rs +++ b/router/client/src/v2/client.rs @@ -143,6 +143,7 @@ impl Client { seed: 0, repetition_penalty: 1.2, frequency_penalty: 0.1, + no_repeat_ngram_size: 0, watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, diff --git a/router/client/src/v2/sharded_client.rs b/router/client/src/v2/sharded_client.rs index 7b24aec3..b53deefe 100644 --- a/router/client/src/v2/sharded_client.rs +++ b/router/client/src/v2/sharded_client.rs @@ -228,6 +228,7 @@ impl Health for ShardedClient { seed: 0, repetition_penalty: 1.0, frequency_penalty: 0.0, + no_repeat_ngram_size: 0, watermark: false, grammar: String::new(), grammar_type: GrammarType::None as i32, diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index a996b14f..1738cd91 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -166,6 +166,7 @@ impl Client { seed: 0, repetition_penalty: 1.2, frequency_penalty: 0.1, + no_repeat_ngram_size: 0, watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index ae8a899b..2a94da8d 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -231,6 +231,7 @@ impl Health for ShardedClient { seed: 0, repetition_penalty: 1.0, frequency_penalty: 0.0, + no_repeat_ngram_size: 0, watermark: false, grammar: String::new(), grammar_type: GrammarType::None as i32, diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 0b51645a..a4e7e0f1 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -377,6 +377,7 @@ impl From for NextTokenChooserParameters { seed: value.seed, repetition_penalty: value.repetition_penalty, frequency_penalty: value.frequency_penalty, + no_repeat_ngram_size: value.no_repeat_ngram_size, watermark: value.watermark, grammar, grammar_type: grammar_type.into(), @@ -420,6 +421,7 @@ mod tests { seed: 0, repetition_penalty: 0.0, frequency_penalty: 0.0, + no_repeat_ngram_size: 0, watermark: false, grammar: None, }, diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 894d9cab..0db6c501 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -440,6 +440,7 @@ impl From for NextTokenChooserParameters { seed: value.seed, repetition_penalty: value.repetition_penalty, frequency_penalty: value.frequency_penalty, + no_repeat_ngram_size: value.no_repeat_ngram_size, watermark: value.watermark, grammar, grammar_type: grammar_type.into(), @@ -483,6 +484,7 @@ mod tests { seed: 0, repetition_penalty: 0.0, frequency_penalty: 0.0, + no_repeat_ngram_size: 0, watermark: false, grammar: None, }, diff --git a/router/src/validation.rs b/router/src/validation.rs index d3f433a7..507fab6e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -200,6 +200,7 @@ impl Validation { temperature, repetition_penalty, frequency_penalty, + no_repeat_ngram_size, top_k, top_p, typical_p, @@ -243,6 +244,8 @@ impl Validation { return Err(ValidationError::FrequencyPenalty); } + let no_repeat_ngram_size = no_repeat_ngram_size.unwrap_or(0); + // Different because the proto default value is not a valid value // for the user let top_p = top_p @@ -370,6 +373,7 @@ impl Validation { temperature, repetition_penalty, frequency_penalty, + no_repeat_ngram_size, top_k, top_p, typical_p,