From 09b7c26bbdb940e4e0d2216e14fd437f89fcdeb2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 8 Feb 2024 18:41:25 +0100 Subject: [PATCH] feat(server): add frequency penalty (#1541) --- Cargo.lock | 38 ++++++- benchmark/src/lib.rs | 3 + benchmark/src/main.rs | 7 ++ benchmark/src/table.rs | 2 + integration-tests/models/test_mamba.py | 11 +- proto/generate.proto | 2 + router/client/src/client.rs | 1 + router/src/health.rs | 1 + router/src/lib.rs | 86 ++++++++++++-- router/src/queue.rs | 1 + router/src/server.rs | 18 ++- router/src/validation.rs | 9 ++ server/tests/utils/test_tokens.py | 2 +- .../text_generation_server/models/__init__.py | 1 + .../models/causal_lm.py | 7 +- .../models/custom_modeling/mamba_modeling.py | 107 ++++++++++++------ .../models/flash_causal_lm.py | 8 +- server/text_generation_server/models/mamba.py | 59 +++++++--- .../models/seq2seq_lm.py | 7 +- server/text_generation_server/models/types.py | 4 +- .../utils/logits_process.py | 56 +++++++++ server/text_generation_server/utils/tokens.py | 55 +++++++-- 22 files changed, 396 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7fdf301a..3318f3b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2787,7 +2787,7 @@ dependencies = [ "tabled", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.14.1", "tokio", "tracing", "tracing-subscriber", @@ -2850,7 +2850,7 @@ dependencies = [ "serde_json", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.15.1", "tokio", "tokio-stream", "tower-http", @@ -2972,6 +2972,40 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.11.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.7.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.35.1" diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 1875652c..6deae48d 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -30,6 +30,7 @@ pub async fn run( top_p: Option<f32>, typical_p: Option<f32>, repetition_penalty: Option<f32>, + frequency_penalty: Option<f32>, watermark: bool, do_sample: bool, client: ShardedClient, @@ -42,6 +43,7 @@ pub async fn run( do_sample, seed: 0, repetition_penalty: repetition_penalty.unwrap_or(1.0), + frequency_penalty: frequency_penalty.unwrap_or(0.0), watermark, }; @@ -140,6 +142,7 @@ pub async fn run( top_p, typical_p, repetition_penalty, + frequency_penalty, watermark, do_sample, ); diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 97c8af1c..2d89e045 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -84,6 +84,11 @@ struct Args { #[clap(long, env)] repetition_penalty: Option<f32>, + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + frequency_penalty: Option<f32>, + /// Generation parameter in case you want to specifically test/debug particular /// decoding strategies, for full doc refer to the `text-generation-server` #[clap(long, env)] 
@@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { top_p, typical_p, repetition_penalty, + frequency_penalty, watermark, do_sample, master_shard_uds_path, @@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { top_p, typical_p, repetition_penalty, + frequency_penalty, watermark, do_sample, sharded_client, diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index 9e36717b..c4819ff3 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -15,6 +15,7 @@ pub(crate) fn parameters_table( top_p: Option<f32>, typical_p: Option<f32>, repetition_penalty: Option<f32>, + frequency_penalty: Option<f32>, watermark: bool, do_sample: bool, ) -> Table { @@ -33,6 +34,7 @@ pub(crate) fn parameters_table( builder.push_record(["Top P", &format!("{top_p:?}")]); builder.push_record(["Typical P", &format!("{typical_p:?}")]); builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]); + builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]); builder.push_record(["Watermark", &watermark.to_string()]); builder.push_record(["Do Sample", &do_sample.to_string()]); diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index d86faeff..bf398999 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot): assert response.generated_text == "\n\nDeep learning is a new type of machine" assert response == response_snapshot + @pytest.mark.asyncio @pytest.mark.private async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): @@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" + assert ( + response.generated_text + == "blue, red, yellow, \nand orange (in the order they appear in" + ) assert response == response_snapshot + @pytest.mark.asyncio @pytest.mark.private async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): - responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) + responses = await generate_load( + fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4 + ) assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses]) diff --git a/proto/generate.proto b/proto/generate.proto index 1f30df38..5140fdaa 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -66,6 +66,8 @@ message NextTokenChooserParameters { uint64 seed = 6; /// repetition penalty float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; /// token watermarking using "A Watermark for Large Language Models" bool watermark = 8; } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 023c5671..fde5c402 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -125,6 +125,7 @@ impl Client { do_sample: false, seed: 0, repetition_penalty: 1.2, + frequency_penalty: 0.1, watermark: true, }), stopping_parameters: Some(StoppingCriteriaParameters { diff --git a/router/src/health.rs b/router/src/health.rs index ab290fc1..e830a3c3 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -43,6 +43,7 @@ impl Health { do_sample: false, seed: 0, repetition_penalty: 1.0, + frequency_penalty: 0.0, watermark: false, }), stopping_parameters: Some(StoppingCriteriaParameters { 
diff --git a/router/src/lib.rs b/router/src/lib.rs index e85519cc..7c44d642 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters { )] pub repetition_penalty: Option<f32>, #[serde(default)] + #[schema( + exclusive_minimum = -2.0, + nullable = true, + default = "null", + example = 0.1 + )] + pub frequency_penalty: Option<f32>, + #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] pub top_k: Option<i32>, #[serde(default)] @@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, + frequency_penalty: None, top_k: None, top_p: None, typical_p: None, @@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletionComplete { pub index: u32, pub message: Message, - pub logprobs: Option<Vec<f32>>, + pub logprobs: Option<ChatCompletionLogprobs>, pub finish_reason: String, } +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct ChatCompletionLogprobs { + content: Vec<ChatCompletionLogprob>, +} + +impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs { + fn from(value: (Token, Vec<Token>)) -> Self { + let (token, top_tokens) = value; + + Self { + content: vec![ChatCompletionLogprob { + token: token.text, + logprob: token.logprob, + top_logprobs: top_tokens + .into_iter() + .map(|t| ChatCompletionTopLogprob { + token: t.text, + logprob: t.logprob, + }) + .collect(), + }], + } + } +} + +impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs { + fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self { + let (tokens, top_tokens) = value; + Self { + content: tokens + .into_iter() + .zip(top_tokens) + .map(|(t, top_t)| ChatCompletionLogprob { + token: t.text, + logprob: t.logprob, + top_logprobs: top_t + .into_iter() + .map(|t| ChatCompletionTopLogprob { + token: t.text, + logprob: t.logprob, + }) + .collect(), + }) + .collect(), + } + } +} + +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct ChatCompletionLogprob { + token: String, + logprob: f32, + top_logprobs: Vec<ChatCompletionTopLogprob>, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct ChatCompletionTopLogprob { + token: String, + logprob: f32, +} + #[derive(Clone, Deserialize, Serialize)] pub(crate) struct Usage { pub prompt_tokens: u32, @@ -238,7 +308,7 @@ impl ChatCompletion { content: output, }, logprobs: return_logprobs - .then(|| details.tokens.iter().map(|t| t.logprob).collect()), + .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), finish_reason: details.finish_reason.to_string(), }], usage: Usage { @@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChoice { pub index: u32, pub delta: ChatCompletionDelta, - pub logprobs: Option<f32>, + pub logprobs: Option<ChatCompletionLogprobs>, pub finish_reason: Option<String>, } @@ -285,7 +355,7 @@ impl ChatCompletionChunk { delta: String, created: u64, index: u32, - logprobs: Option<f32>, + logprobs: Option<ChatCompletionLogprobs>, finish_reason: Option<String>, ) -> Self { Self { @@ -319,8 +389,8 @@ pub(crate) struct ChatRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, /* NOTE: UNUSED */ - + pub model: String, + /* NOTE: UNUSED */ /// A list of messages comprising the conversation so far. 
#[serde(default = "default_request_messages")] pub messages: Vec, @@ -346,7 +416,6 @@ pub(crate) struct ChatRequest { #[schema(example = "false")] pub logprobs: Option, - /// UNUSED /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with /// an associated log probability. logprobs must be set to true if this parameter is used. #[serde(default)] @@ -365,7 +434,6 @@ pub(crate) struct ChatRequest { #[schema(nullable = true, example = "2")] pub n: Option, - /// UNUSED /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, /// increasing the model's likelihood to talk about new topics #[serde(default)] @@ -447,7 +515,7 @@ pub struct PrefillToken { logprob: f32, } -#[derive(Debug, Serialize, ToSchema)] +#[derive(Debug, Serialize, ToSchema, Clone)] pub struct Token { #[schema(example = 0)] id: u32, diff --git a/router/src/queue.rs b/router/src/queue.rs index 106cacc4..73a7169b 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -355,6 +355,7 @@ mod tests { do_sample: false, seed: 0, repetition_penalty: 0.0, + frequency_penalty: 0.0, watermark: false, }, stopping_parameters: StoppingCriteriaParameters { diff --git a/router/src/server.rs b/router/src/server.rs index b4d26158..acfdef91 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, - ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, - GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, - PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation, + ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, + FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, + HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, + StreamResponse, Token, TokenizeResponse, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -570,8 +571,8 @@ async fn chat_completions( let stream = req.stream; let max_new_tokens = req.max_tokens.or(Some(100)); let repetition_penalty = req - .frequency_penalty - // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0) + .presence_penalty + // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0) .map(|x| x + 2.0); let logprobs = req.logprobs.unwrap_or(false); let seed = req.seed; @@ -599,6 +600,7 @@ async fn chat_completions( best_of: None, temperature: req.temperature, repetition_penalty, + frequency_penalty: req.frequency_penalty, top_k: None, top_p: req.top_p, typical_p: None, @@ -630,6 +632,10 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); + let logprobs = logprobs.then(|| { + ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) + }); + event .json_data(ChatCompletionChunk::new( model_id.clone(), @@ -637,7 +643,7 @@ async fn chat_completions( stream_token.token.text, current_time, stream_token.index, - logprobs.then_some(stream_token.token.logprob), + logprobs, stream_token.details.map(|d| d.finish_reason.to_string()), )) .map_or_else( diff --git a/router/src/validation.rs b/router/src/validation.rs index 750b98e5..e6874b11 
100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -170,6 +170,7 @@ impl Validation { best_of, temperature, repetition_penalty, + frequency_penalty, top_k, top_p, typical_p, @@ -206,6 +207,11 @@ impl Validation { return Err(ValidationError::RepetitionPenalty); } + let frequency_penalty = frequency_penalty.unwrap_or(0.0); + if !(-2.0..=2.0).contains(&frequency_penalty) { + return Err(ValidationError::FrequencyPenalty); + } + // Different because the proto default value is not a valid value // for the user let top_p = top_p @@ -289,6 +295,7 @@ impl Validation { let parameters = NextTokenChooserParameters { temperature, repetition_penalty, + frequency_penalty, top_k, top_p, typical_p, @@ -420,6 +427,8 @@ pub enum ValidationError { Temperature, #[error("`repetition_penalty` must be strictly positive")] RepetitionPenalty, + #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")] + FrequencyPenalty, #[error("`top_p` must be > 0.0 and < 1.0")] TopP, #[error("`top_k` must be strictly positive")] diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index d3f2d766..5db32776 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -70,7 +70,7 @@ def test_batch_top_tokens(): # Now let's make second member of the batch be speculated inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2) - accepted_ids[1] = 2 + accepted_ids[1] = 2 topn_tok_ids, topn_tok_logprobs = batch_top_tokens( top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a952f060..da7d8416 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -86,6 +86,7 @@ except ImportError as e: if MAMBA_AVAILABLE: __all__.append(Mamba) + def get_model( model_id: str, revision: Optional[str], diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 29e9f8b1..a7a16212 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -696,14 +696,17 @@ class CausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): + for (top_token_ids, top_token_logprobs) in zip( + top_token_ids, top_token_logprobs + ): toptoken_texts = self.tokenizer.batch_decode( top_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids + token_id in self.all_special_ids + for token_id in top_token_ids ] top_tokens = Tokens( top_token_ids, diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 1773f04d..017c0341 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -19,6 +19,7 @@ from einops import rearrange from causal_conv1d import causal_conv1d_fn, causal_conv1d_update import math + class MambaConfig(PretrainedConfig): def __init__( self, @@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig): **kwargs, ) + class MambaBlock(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -60,10 +62,14 @@ class MambaBlock(nn.Module): self.in_proj = 
FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) - self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False) - self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False) + self.dt_proj_no_bias = FastLinear.load( + config, f"{prefix}.dt_proj", weights, bias=False + ) + self.out_proj = FastLinear.load( + config, f"{prefix}.out_proj", weights, bias=False + ) self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) - self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) + self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) self.D = weights.get_tensor(f"{prefix}.D") self.activation = "silu" self.dt_rank = config.dt_rank @@ -80,12 +86,14 @@ class MambaBlock(nn.Module): out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) return out, conv_state, ssm_state - projected_states = self.in_proj(hidden_states).transpose(1,2) + projected_states = self.in_proj(hidden_states).transpose(1, 2) x, z = projected_states.chunk(2, dim=1) conv_state = F.pad(x, (self.d_conv - seqlen, 0)) x = causal_conv1d_fn( x=x, - weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + weight=self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ), bias=self.conv1d.bias, activation=self.activation, ) @@ -94,7 +102,9 @@ class MambaBlock(nn.Module): # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt, B, C = torch.split( + x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 + ) dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() @@ -118,28 +128,39 @@ class MambaBlock(nn.Module): def step(self, hidden_states, conv_state, ssm_state): _xz = self.in_proj(hidden_states) _x, _z = _xz.chunk(2, dim=-1) # (B D) - conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) - conv_out = causal_conv1d_fn( - x=conv_state_new, - weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), - bias=self.conv1d.bias, - activation=self.activation + conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1) + conv_out = causal_conv1d_fn( + x=conv_state_new, + weight=self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ), + bias=self.conv1d.bias, + activation=self.activation, ) conv_state = conv_state_new[:, :, 1:] bsz, seqlen, dim = hidden_states.shape output_tensor = torch.zeros( - (bsz, seqlen, dim), - device=hidden_states.device, - dtype=hidden_states.dtype + (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype ) for i in range(0, bsz): - x = conv_out[i:i+1,:,-1] - z = _z[i:i+1, -1, :] + x = conv_out[i : i + 1, :, -1] + z = _z[i : i + 1, -1, :] x_db = self.x_proj(x) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt, B, C = torch.split( + x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1 + ) dt = F.linear(dt, self.dt_proj.weight) y = selective_state_update( - ssm_state[i:i+1,:,:], x, dt, 
self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state[i : i + 1, :, :], + x, + dt, + self.negA, + B, + C, + self.D, + z=z, + dt_bias=self.dt_proj.bias, + dt_softplus=True, ) out = self.out_proj(y) output_tensor[i] = out @@ -147,48 +168,70 @@ class MambaBlock(nn.Module): return output_tensor, conv_state, ssm_state - class ResidualBlock(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() - self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) - self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) + self.mamba_block = MambaBlock( + prefix=f"{layer_id}.mixer", config=config, weights=weights + ) + self.layer_norm = FastRMSNorm.load( + prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon + ) def forward( self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None, inference_params: Optional[Any] = None, - ): + ): residual = (hidden_states + residual) if residual is not None else hidden_states shape = residual.shape hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) - hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params) + hidden_states, conv_state, last_ssm_state = self.mamba_block( + hidden_states.view(*shape), inference_params + ) return hidden_states, residual, conv_state, last_ssm_state + class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( - [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] + [ + ResidualBlock(f"{prefix}.layers.{i}", config, weights) + for i in range(config.n_layer) + ] + ) + self.norm_f = FastRMSNorm.load( + f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon + ) + self.lm_head = FastLinear.load( + config, f"{prefix}.embedding", weights, bias=False ) - self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon) - self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False) self.config = config - def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + def forward( + self, input_ids: torch.Tensor, inference_params=None, residual=None + ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: hidden_states = self.embed_tokens(input_ids) for block in self.blocks: - hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params) - inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) + hidden_states, residual, conv_state, ssm_state = block( + hidden_states, residual, inference_params + ) + inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = ( + conv_state, + ssm_state, + ) - hidden_states = hidden_states + residual if residual is not None else hidden_states + hidden_states = ( + hidden_states + residual if residual is not None else hidden_states + ) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states = hidden_states.view(residual.shape) logits = self.lm_head(hidden_states) # update the offset for the next inference using these params inference_params.seqlen_offset += input_ids.size(1) - return logits, input_ids, inference_params \ No newline at end of file + return 
logits, input_ids, inference_params diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 53a3d582..90776654 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -842,7 +842,6 @@ class FlashCausalLM(Model): else: next_token_logits = out - speculate = get_speculate() ( next_input_ids, @@ -1064,14 +1063,17 @@ class FlashCausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): + for (top_token_ids, top_token_logprobs) in zip( + top_token_ids, top_token_logprobs + ): toptoken_texts = self.tokenizer.batch_decode( top_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids + token_id in self.all_special_ids + for token_id in top_token_ids ] top_tokens = Tokens( top_token_ids, diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index c10910aa..c51e1e20 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -26,6 +26,7 @@ from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from mamba_ssm.utils.generation import InferenceParams + @dataclass class MambaBatch(Batch): batch_id: int @@ -69,7 +70,7 @@ class MambaBatch(Batch): size=len(self), max_tokens=self.max_tokens, ) - + @classmethod def from_pb( cls, @@ -196,7 +197,7 @@ class MambaBatch(Batch): new_padding_right_offset = max( new_padding_right_offset, remaining_decode_tokens ) - + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] @@ -218,10 +219,13 @@ class MambaBatch(Batch): self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - # TODO + # TODO # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. 
key_value_memory_dict = {} - for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items(): + for i, ( + conv_state, + ssm_state, + ) in self.inference_params.key_value_memory_dict.items(): key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices]) self.inference_params.key_value_memory_dict = key_value_memory_dict @@ -305,8 +309,9 @@ class MambaBatch(Batch): start_index = end_index - - (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape + (_, d_model, d_conv) = ( + batches[0].inference_params.key_value_memory_dict[0][0].shape + ) (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape n_blocks = len(batches[0].inference_params.key_value_memory_dict) dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype @@ -344,9 +349,15 @@ class MambaBatch(Batch): for i in range(n_blocks): conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i] batch_size = batch.inference_params.max_batch_size - inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state - inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state - inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample + inference_params.key_value_memory_dict[i][0][ + current_batch : current_batch + batch_size + ] = conv_state + inference_params.key_value_memory_dict[i][1][ + current_batch : current_batch + batch_size + ] = ssm_state + inference_params.lengths_per_sample[ + current_batch : current_batch + batch_size + ] = batch.inference_params.lengths_per_sample current_batch += batch_size return cls( @@ -366,12 +377,13 @@ class MambaBatch(Batch): padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, max_tokens=max_tokens, - inference_params=inference_params + inference_params=inference_params, ) def __len__(self): return len(self.requests) + class Mamba(Model): def __init__( self, @@ -428,7 +440,7 @@ class Mamba(Model): def warmup(self, batch) -> Optional[int]: # TODO: implement warmup for Mamba if needed return None - + def forward( self, input_ids: torch.Tensor, @@ -441,7 +453,9 @@ class Mamba(Model): def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() - input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids + input_ids = ( + batch.input_ids + ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids batch_size = input_ids.shape[0] max_seqlen = input_ids.shape[1] @@ -450,8 +464,11 @@ class Mamba(Model): # Inference params seqlen_og = 0 inf_cache = {} - lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen - + lengths_per_sample = ( + torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) + * max_seqlen + ) + if batch.inference_params is None: inference_params = InferenceParams( max_seqlen=max_seqlen, @@ -478,11 +495,16 @@ class Mamba(Model): device=block.dt_proj.weight.device, dtype=block.dt_proj.weight.dtype, ) - inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state) + inference_params.key_value_memory_dict[block.layer_idx] = ( + conv_state, + ssm_state, + ) batch.inference_params = inference_params - + # Forward pass - logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params) + logits, 
past_input_ids, new_inference_params = self.model( + input_ids, batch.inference_params + ) batch.inference_params = new_inference_params # Results @@ -564,7 +586,8 @@ class Mamba(Model): prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, skip_special_tokens=True, ) # Get seed diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 8b93aecd..25042a32 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -750,14 +750,17 @@ class Seq2SeqLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): + for (top_token_ids, top_token_logprobs) in zip( + top_token_ids, top_token_logprobs + ): toptoken_texts = self.tokenizer.batch_decode( top_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids + token_id in self.all_special_ids + for token_id in top_token_ids ] top_tokens = Tokens( top_token_ids, diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index bc68812e..da71b0ec 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -95,5 +95,7 @@ class Generation: generated_text=self.generated_text.to_pb() if self.generated_text is not None else None, - top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None, + top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] + if self.top_tokens is not None + else None, ) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index f424eae4..291c522f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): return None +class FrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI + + Args: + penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. + """ + + def __init__(self, penalty: float): + self.penalty = penalty + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + score = torch.gather(scores, 1, input_ids) + # if score < 0 then penalty has to be multiplied to reduce the previous token probability + score = -torch.where( + score < 0, score * self.penalty, score / self.penalty + ) + + return scores.scatter_add_(1, input_ids, score) + + +class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI + + Args: + frequency_penalty (`List[float]`): + The parameter for frequency penalty. 0.0 means no penalty. 
+ """ + + def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): + self.penalty = penalty + self.penalty_tensor = torch.tensor( + penalty, dtype=dtype, device=device + ).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + score = torch.gather(scores, 1, input_ids) + # if score < 0 then penalty has to be multiplied to reduce the previous token probability + score = -torch.where( + score < 0, score * self.penalty_tensor, score / self.penalty_tensor + ) + + return scores.scatter_add_(1, input_ids, score) + + def filter(self, indices): + self.penalty = [self.penalty[i] for i in indices] + if any([x != 0.0 for x in self.penalty]): + self.penalty_tensor = self.penalty_tensor[indices] + return self + return None + + class HeterogeneousTemperatureLogitsWarper: r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 270a6990..d6ca10c7 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,12 +1,14 @@ import re -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.logits_process import ( + FrequencyPenaltyLogitsProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, + HeterogeneousFrequencyPenaltyLogitsProcessor, HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, @@ -23,6 +25,7 @@ class NextTokenChooser: watermark=False, temperature=1.0, repetition_penalty=1.0, + frequency_penalty=0.0, top_k=None, top_p=None, typical_p=None, @@ -35,7 +38,12 @@ class NextTokenChooser: ) self.repetition_processor = ( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) - if repetition_penalty + if repetition_penalty and repetition_penalty != 1.0 + else None + ) + self.frequency_processor = ( + FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty) + if frequency_penalty and frequency_penalty != 0.0 else None ) @@ -60,6 +68,8 @@ class NextTokenChooser: scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) + if self.frequency_processor is not None: + scores = self.frequency_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -80,6 +90,7 @@ class NextTokenChooser: watermark=pb.watermark, temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, + frequency_penalty=pb.frequency_penalty, top_k=pb.top_k, top_p=pb.top_p, typical_p=pb.typical_p, @@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser: watermark: List[bool], temperature: List[float], repetition_penalty: List[float], + frequency_penalty: List[float], top_k: List[int], top_p: List[float], typical_p: List[float], @@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser: else None ) + self.frequency_processor = ( + HeterogeneousFrequencyPenaltyLogitsProcessor( + frequency_penalty, dtype, device + ) + if any([x != 0.0 for x in frequency_penalty]) + else None + ) + if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -269,6 +289,8 @@ class 
HeterogeneousNextTokenChooser: _scores = self.watermark_processor(input_ids, _scores) if self.repetition_processor is not None: _scores = self.repetition_processor(input_ids, _scores) + if self.frequency_processor is not None: + _scores = self.frequency_processor(input_ids, _scores) for warper in self.warpers: _scores = warper(input_ids, _scores) @@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser: next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) - if speculate > 0: if speculative_scores is not None: # Medusa provided some scores @@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser: if self.repetition_processor is not None: self.repetition_processor = self.repetition_processor.filter(indices) + if self.frequency_processor is not None: + self.frequency_processor = self.frequency_processor.filter(indices) + filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) @@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser: watermark=[pb_.watermark for pb_ in pb], temperature=[pb_.temperature for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + frequency_penalty=[pb_.frequency_penalty for pb_ in pb], top_k=[pb_.top_k for pb_ in pb], top_p=[pb_.top_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb], @@ -438,7 +463,10 @@ class HeterogeneousSampling: def batch_top_tokens( - top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor + top_n_tokens: List[int], + top_n_tokens_tensor: torch.Tensor, + logprobs: torch.Tensor, + accepted_ids: torch.Tensor, ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: """Find the top n most likely tokens for a batch of generations. @@ -450,12 +478,15 @@ def batch_top_tokens( if max_top_n == 0: return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens) - batch_size = accepted_ids.shape[0] speculate_size = logprobs.shape[0] // batch_size top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size) # Ensure top_n doesn't exceed vocab size - top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)] + top_n_tokens = [ + min(tok, logprobs.size(-1)) + for tok in top_n_tokens + for _ in range(speculate_size) + ] # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 # Sorted topk is faster than torch.sort() since we only need a small subset @@ -484,10 +515,10 @@ def batch_top_tokens( for i, n_accepted_ids in enumerate(accepted_ids_list): start = speculate_size * i stop = speculate_size * (i + 1) - _top_indices = top_indices[start: stop] - _top_values = top_values[start: stop] - _top_n_ishes = top_n_ishes[start: stop] - _top_n_tokens = top_n_tokens[start: stop] + _top_indices = top_indices[start:stop] + _top_values = top_values[start:stop] + _top_n_ishes = top_n_ishes[start:stop] + _top_n_tokens = top_n_tokens[start:stop] _top_indices = _top_indices[:n_accepted_ids] _top_values = _top_values[:n_accepted_ids] @@ -497,7 +528,9 @@ def batch_top_tokens( row_top_token_ids = [] row_top_token_logprobs = [] - for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens): + for idxs, vals, n, req_n in zip( + _top_indices, _top_values, _top_n_ishes, _top_n_tokens + ): indices = idxs[:n] if req_n > 0 else [] values = vals[:n] if req_n > 0 else []
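
The frequency penalty introduced by this patch is applied in server/text_generation_server/utils/logits_process.py by gathering the logits of the previously generated tokens, scaling them by the penalty, and scatter-adding the negated result back into the score tensor. The standalone sketch below mirrors that gather/scatter_add update on toy tensors; the shapes and values are invented for illustration and the snippet is not part of the diff.

import torch

# Toy setup: one sequence, a vocabulary of 5 tokens; token id 2 has already been
# generated twice. `penalty` stands in for the request's frequency_penalty
# (0.0 disables the processor entirely).
penalty = 0.5
scores = torch.tensor([[1.0, -2.0, 3.0, 0.5, -1.0]])
input_ids = torch.tensor([[2, 2]])

# Gather the current logits of the already-generated tokens.
score = torch.gather(scores, 1, input_ids)
# Negative logits are multiplied by the penalty, positive logits divided by it,
# then the result is negated and scatter-added back onto the full score row.
score = -torch.where(score < 0, score * penalty, score / penalty)
scores = scores.scatter_add_(1, input_ids, score)
print(scores)  # the logit of token id 2 is pushed down relative to the others

On the API side, the same penalty is exposed as the optional frequency_penalty field of GenerateParameters (validated in router/src/validation.rs to lie in [-2.0, 2.0]) and of the OpenAI-compatible chat endpoint, and it is forwarded to the server through the new frequency_penalty field of NextTokenChooserParameters.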