feat(server): add frequency penalty (#1541)

This commit is contained in:
OlivierDehaene 2024-02-08 18:41:25 +01:00 committed by GitHub
parent 39af000cb9
commit 09b7c26bbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 396 additions and 89 deletions

38
Cargo.lock generated
View File

@ -2787,7 +2787,7 @@ dependencies = [
"tabled",
"text-generation-client",
"thiserror",
"tokenizers",
"tokenizers 0.14.1",
"tokio",
"tracing",
"tracing-subscriber",
@ -2850,7 +2850,7 @@ dependencies = [
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers",
"tokenizers 0.15.1",
"tokio",
"tokio-stream",
"tower-http",
@ -2972,6 +2972,40 @@ dependencies = [
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
dependencies = [
"aho-corasick",
"clap",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.11.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokio"
version = "1.35.1"

View File

@ -30,6 +30,7 @@ pub async fn run(
top_p: Option<f32>,
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
client: ShardedClient,
@ -42,6 +43,7 @@ pub async fn run(
do_sample,
seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
};
@ -140,6 +142,7 @@ pub async fn run(
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
);

View File

@ -84,6 +84,11 @@ struct Args {
#[clap(long, env)]
repetition_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
frequency_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
master_shard_uds_path,
@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
sharded_client,

View File

@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
top_p: Option<f32>,
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
) -> Table {
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Top P", &format!("{top_p:?}")]);
builder.push_record(["Typical P", &format!("{typical_p:?}")]);
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]);

View File

@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
assert response.generated_text == "\n\nDeep learning is a new type of machine"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
assert (
response.generated_text
== "blue, red, yellow, \nand orange (in the order they appear in"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
responses = await generate_load(
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])

View File

@ -66,6 +66,8 @@ message NextTokenChooserParameters {
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
}

View File

@ -125,6 +125,7 @@ impl Client {
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -43,6 +43,7 @@ impl Health {
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
)]
pub repetition_penalty: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = -2.0,
nullable = true,
default = "null",
example = 0.1
)]
pub frequency_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
#[serde(default)]
@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
pub logprobs: Option<Vec<f32>>,
pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprobs {
content: Vec<ChatCompletionLogprob>,
}
impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
fn from(value: (Token, Vec<Token>)) -> Self {
let (token, top_tokens) = value;
Self {
content: vec![ChatCompletionLogprob {
token: token.text,
logprob: token.logprob,
top_logprobs: top_tokens
.into_iter()
.map(|t| ChatCompletionTopLogprob {
token: t.text,
logprob: t.logprob,
})
.collect(),
}],
}
}
}
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
let (tokens, top_tokens) = value;
Self {
content: tokens
.into_iter()
.zip(top_tokens)
.map(|(t, top_t)| ChatCompletionLogprob {
token: t.text,
logprob: t.logprob,
top_logprobs: top_t
.into_iter()
.map(|t| ChatCompletionTopLogprob {
token: t.text,
logprob: t.logprob,
})
.collect(),
})
.collect(),
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprob {
token: String,
logprob: f32,
top_logprobs: Vec<ChatCompletionTopLogprob>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionTopLogprob {
token: String,
logprob: f32,
}
#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct Usage {
pub prompt_tokens: u32,
@ -238,7 +308,7 @@ impl ChatCompletion {
content: output,
},
logprobs: return_logprobs
.then(|| details.tokens.iter().map(|t| t.logprob).collect()),
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(),
}],
usage: Usage {
@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
pub logprobs: Option<f32>,
pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: Option<String>,
}
@ -285,7 +355,7 @@ impl ChatCompletionChunk {
delta: String,
created: u64,
index: u32,
logprobs: Option<f32>,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
) -> Self {
Self {
@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, /* NOTE: UNUSED */
pub model: String,
/* NOTE: UNUSED */
/// A list of messages comprising the conversation so far.
#[serde(default = "default_request_messages")]
pub messages: Vec<Message>,
@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
@ -447,7 +515,7 @@ pub struct PrefillToken {
logprob: f32,
}
#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token {
#[schema(example = 0)]
id: u32,

View File

@ -355,6 +355,7 @@ mod tests {
do_sample: false,
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {

View File

@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -570,8 +571,8 @@ async fn chat_completions(
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let repetition_penalty = req
.frequency_penalty
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
.presence_penalty
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
.map(|x| x + 2.0);
let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed;
@ -599,6 +600,7 @@ async fn chat_completions(
best_of: None,
temperature: req.temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
@ -630,6 +632,10 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});
event
.json_data(ChatCompletionChunk::new(
model_id.clone(),
@ -637,7 +643,7 @@ async fn chat_completions(
stream_token.token.text,
current_time,
stream_token.index,
logprobs.then_some(stream_token.token.logprob),
logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()),
))
.map_or_else(

View File

@ -170,6 +170,7 @@ impl Validation {
best_of,
temperature,
repetition_penalty,
frequency_penalty,
top_k,
top_p,
typical_p,
@ -206,6 +207,11 @@ impl Validation {
return Err(ValidationError::RepetitionPenalty);
}
let frequency_penalty = frequency_penalty.unwrap_or(0.0);
if !(-2.0..=2.0).contains(&frequency_penalty) {
return Err(ValidationError::FrequencyPenalty);
}
// Different because the proto default value is not a valid value
// for the user
let top_p = top_p
@ -289,6 +295,7 @@ impl Validation {
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
frequency_penalty,
top_k,
top_p,
typical_p,
@ -420,6 +427,8 @@ pub enum ValidationError {
Temperature,
#[error("`repetition_penalty` must be strictly positive")]
RepetitionPenalty,
#[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
FrequencyPenalty,
#[error("`top_p` must be > 0.0 and < 1.0")]
TopP,
#[error("`top_k` must be strictly positive")]

View File

@ -70,7 +70,7 @@ def test_batch_top_tokens():
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)

View File

@ -86,6 +86,7 @@ except ImportError as e:
if MAMBA_AVAILABLE:
__all__.append(Mamba)
def get_model(
model_id: str,
revision: Optional[str],

View File

@ -696,14 +696,17 @@ class CausalLM(Model):
if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,

View File

@ -19,6 +19,7 @@ from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
class MambaConfig(PretrainedConfig):
def __init__(
self,
@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
**kwargs,
)
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@ -60,10 +62,14 @@ class MambaBlock(nn.Module):
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
self.dt_proj_no_bias = FastLinear.load(
config, f"{prefix}.dt_proj", weights, bias=False
)
self.out_proj = FastLinear.load(
config, f"{prefix}.out_proj", weights, bias=False
)
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
self.D = weights.get_tensor(f"{prefix}.D")
self.activation = "silu"
self.dt_rank = config.dt_rank
@ -80,12 +86,14 @@ class MambaBlock(nn.Module):
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
return out, conv_state, ssm_state
projected_states = self.in_proj(hidden_states).transpose(1,2)
projected_states = self.in_proj(hidden_states).transpose(1, 2)
x, z = projected_states.chunk(2, dim=1)
conv_state = F.pad(x, (self.d_conv - seqlen, 0))
x = causal_conv1d_fn(
x=x,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
weight=self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias,
activation=self.activation,
)
@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt, B, C = torch.split(
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@ -118,28 +128,39 @@ class MambaBlock(nn.Module):
def step(self, hidden_states, conv_state, ssm_state):
_xz = self.in_proj(hidden_states)
_x, _z = _xz.chunk(2, dim=-1) # (B D)
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
conv_out = causal_conv1d_fn(
x=conv_state_new,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
bias=self.conv1d.bias,
activation=self.activation
conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
conv_out = causal_conv1d_fn(
x=conv_state_new,
weight=self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias,
activation=self.activation,
)
conv_state = conv_state_new[:, :, 1:]
bsz, seqlen, dim = hidden_states.shape
output_tensor = torch.zeros(
(bsz, seqlen, dim),
device=hidden_states.device,
dtype=hidden_states.dtype
(bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
)
for i in range(0, bsz):
x = conv_out[i:i+1,:,-1]
z = _z[i:i+1, -1, :]
x = conv_out[i : i + 1, :, -1]
z = _z[i : i + 1, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt, B, C = torch.split(
x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
dt = F.linear(dt, self.dt_proj.weight)
y = selective_state_update(
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
ssm_state[i : i + 1, :, :],
x,
dt,
self.negA,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
)
out = self.out_proj(y)
output_tensor[i] = out
@ -147,48 +168,70 @@ class MambaBlock(nn.Module):
return output_tensor, conv_state, ssm_state
class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
self.mamba_block = MambaBlock(
prefix=f"{layer_id}.mixer", config=config, weights=weights
)
self.layer_norm = FastRMSNorm.load(
prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
)
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None,
inference_params: Optional[Any] = None,
):
):
residual = (hidden_states + residual) if residual is not None else hidden_states
shape = residual.shape
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
hidden_states, conv_state, last_ssm_state = self.mamba_block(
hidden_states.view(*shape), inference_params
)
return hidden_states, residual, conv_state, last_ssm_state
class MambaModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList(
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
[
ResidualBlock(f"{prefix}.layers.{i}", config, weights)
for i in range(config.n_layer)
]
)
self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
)
self.lm_head = FastLinear.load(
config, f"{prefix}.embedding", weights, bias=False
)
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
self.config = config
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
def forward(
self, input_ids: torch.Tensor, inference_params=None, residual=None
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
hidden_states = self.embed_tokens(input_ids)
for block in self.blocks:
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
hidden_states, residual, conv_state, ssm_state = block(
hidden_states, residual, inference_params
)
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
conv_state,
ssm_state,
)
hidden_states = hidden_states + residual if residual is not None else hidden_states
hidden_states = (
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
hidden_states = hidden_states.view(residual.shape)
logits = self.lm_head(hidden_states)
# update the offset for the next inference using these params
inference_params.seqlen_offset += input_ids.size(1)
return logits, input_ids, inference_params
return logits, input_ids, inference_params

View File

@ -842,7 +842,6 @@ class FlashCausalLM(Model):
else:
next_token_logits = out
speculate = get_speculate()
(
next_input_ids,
@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,

View File

@ -26,6 +26,7 @@ from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from mamba_ssm.utils.generation import InferenceParams
@dataclass
class MambaBatch(Batch):
batch_id: int
@ -69,7 +70,7 @@ class MambaBatch(Batch):
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
@ -196,7 +197,7 @@ class MambaBatch(Batch):
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
@ -218,10 +219,13 @@ class MambaBatch(Batch):
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
# TODO
# TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
key_value_memory_dict = {}
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
for i, (
conv_state,
ssm_state,
) in self.inference_params.key_value_memory_dict.items():
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
self.inference_params.key_value_memory_dict = key_value_memory_dict
@ -305,8 +309,9 @@ class MambaBatch(Batch):
start_index = end_index
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
(_, d_model, d_conv) = (
batches[0].inference_params.key_value_memory_dict[0][0].shape
)
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
@ -344,9 +349,15 @@ class MambaBatch(Batch):
for i in range(n_blocks):
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
batch_size = batch.inference_params.max_batch_size
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
inference_params.key_value_memory_dict[i][0][
current_batch : current_batch + batch_size
] = conv_state
inference_params.key_value_memory_dict[i][1][
current_batch : current_batch + batch_size
] = ssm_state
inference_params.lengths_per_sample[
current_batch : current_batch + batch_size
] = batch.inference_params.lengths_per_sample
current_batch += batch_size
return cls(
@ -366,12 +377,13 @@ class MambaBatch(Batch):
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
inference_params=inference_params
inference_params=inference_params,
)
def __len__(self):
return len(self.requests)
class Mamba(Model):
def __init__(
self,
@ -428,7 +440,7 @@ class Mamba(Model):
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
return None
def forward(
self,
input_ids: torch.Tensor,
@ -441,7 +453,9 @@ class Mamba(Model):
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
input_ids = (
batch.input_ids
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size = input_ids.shape[0]
max_seqlen = input_ids.shape[1]
@ -450,8 +464,11 @@ class Mamba(Model):
# Inference params
seqlen_og = 0
inf_cache = {}
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
lengths_per_sample = (
torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
* max_seqlen
)
if batch.inference_params is None:
inference_params = InferenceParams(
max_seqlen=max_seqlen,
@ -478,11 +495,16 @@ class Mamba(Model):
device=block.dt_proj.weight.device,
dtype=block.dt_proj.weight.dtype,
)
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
inference_params.key_value_memory_dict[block.layer_idx] = (
conv_state,
ssm_state,
)
batch.inference_params = inference_params
# Forward pass
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
logits, past_input_ids, new_inference_params = self.model(
input_ids, batch.inference_params
)
batch.inference_params = new_inference_params
# Results
@ -564,7 +586,8 @@ class Mamba(Model):
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed

View File

@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,

View File

@ -95,5 +95,7 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None,
)

View File

@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return None
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: float):
self.penalty = penalty
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty, score / self.penalty
)
return scores.scatter_add_(1, input_ids, score)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
frequency_penalty (`List[float]`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty
self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
return scores.scatter_add_(1, input_ids, score)
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]
if any([x != 0.0 for x in self.penalty]):
self.penalty_tensor = self.penalty_tensor[indices]
return self
return None
class HeterogeneousTemperatureLogitsWarper:
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).

View File

@ -1,12 +1,14 @@
import re
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
@ -23,6 +25,7 @@ class NextTokenChooser:
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
frequency_penalty=0.0,
top_k=None,
top_p=None,
typical_p=None,
@ -35,7 +38,12 @@ class NextTokenChooser:
)
self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty
if repetition_penalty and repetition_penalty != 1.0
else None
)
self.frequency_processor = (
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
if frequency_penalty and frequency_penalty != 0.0
else None
)
@ -60,6 +68,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@ -80,6 +90,7 @@ class NextTokenChooser:
watermark=pb.watermark,
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
frequency_penalty=pb.frequency_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
typical_p=pb.typical_p,
@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
else None
)
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
@ -438,7 +463,10 @@ class HeterogeneousSampling:
def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations.
@ -450,12 +478,15 @@ def batch_top_tokens(
if max_top_n == 0:
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
@ -484,10 +515,10 @@ def batch_top_tokens(
for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i
stop = speculate_size * (i + 1)
_top_indices = top_indices[start: stop]
_top_values = top_values[start: stop]
_top_n_ishes = top_n_ishes[start: stop]
_top_n_tokens = top_n_tokens[start: stop]
_top_indices = top_indices[start:stop]
_top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start:stop]
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
@ -497,7 +528,9 @@ def batch_top_tokens(
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []