feat(server): add frequency penalty (#1541)

This commit is contained in:
OlivierDehaene 2024-02-08 18:41:25 +01:00 committed by GitHub
parent 39af000cb9
commit 09b7c26bbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 396 additions and 89 deletions

38
Cargo.lock generated
View File

@ -2787,7 +2787,7 @@ dependencies = [
"tabled", "tabled",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.14.1",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -2850,7 +2850,7 @@ dependencies = [
"serde_json", "serde_json",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers", "tokenizers 0.15.1",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http", "tower-http",
@ -2972,6 +2972,40 @@ dependencies = [
"unicode_categories", "unicode_categories",
] ]
[[package]]
name = "tokenizers"
version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
dependencies = [
"aho-corasick",
"clap",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.11.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.35.1" version = "1.35.1"

View File

@ -30,6 +30,7 @@ pub async fn run(
top_p: Option<f32>, top_p: Option<f32>,
typical_p: Option<f32>, typical_p: Option<f32>,
repetition_penalty: Option<f32>, repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool, watermark: bool,
do_sample: bool, do_sample: bool,
client: ShardedClient, client: ShardedClient,
@ -42,6 +43,7 @@ pub async fn run(
do_sample, do_sample,
seed: 0, seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0), repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark, watermark,
}; };
@ -140,6 +142,7 @@ pub async fn run(
top_p, top_p,
typical_p, typical_p,
repetition_penalty, repetition_penalty,
frequency_penalty,
watermark, watermark,
do_sample, do_sample,
); );

View File

@ -84,6 +84,11 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
repetition_penalty: Option<f32>, repetition_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
frequency_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular /// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server` /// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)] #[clap(long, env)]
@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
top_p, top_p,
typical_p, typical_p,
repetition_penalty, repetition_penalty,
frequency_penalty,
watermark, watermark,
do_sample, do_sample,
master_shard_uds_path, master_shard_uds_path,
@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
top_p, top_p,
typical_p, typical_p,
repetition_penalty, repetition_penalty,
frequency_penalty,
watermark, watermark,
do_sample, do_sample,
sharded_client, sharded_client,

View File

@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
top_p: Option<f32>, top_p: Option<f32>,
typical_p: Option<f32>, typical_p: Option<f32>,
repetition_penalty: Option<f32>, repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool, watermark: bool,
do_sample: bool, do_sample: bool,
) -> Table { ) -> Table {
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Top P", &format!("{top_p:?}")]); builder.push_record(["Top P", &format!("{top_p:?}")]);
builder.push_record(["Typical P", &format!("{typical_p:?}")]); builder.push_record(["Typical P", &format!("{typical_p:?}")]);
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]); builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]); builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]); builder.push_record(["Do Sample", &do_sample.to_string()]);

View File

@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
assert response.generated_text == "\n\nDeep learning is a new type of machine" assert response.generated_text == "\n\nDeep learning is a new type of machine"
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" assert (
response.generated_text
== "blue, red, yellow, \nand orange (in the order they appear in"
)
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) responses = await generate_load(
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])

View File

@ -66,6 +66,8 @@ message NextTokenChooserParameters {
uint64 seed = 6; uint64 seed = 6;
/// repetition penalty /// repetition penalty
float repetition_penalty = 7; float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
} }

View File

@ -125,6 +125,7 @@ impl Client {
do_sample: false, do_sample: false,
seed: 0, seed: 0,
repetition_penalty: 1.2, repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true, watermark: true,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -43,6 +43,7 @@ impl Health {
do_sample: false, do_sample: false,
seed: 0, seed: 0,
repetition_penalty: 1.0, repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false, watermark: false,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
)] )]
pub repetition_penalty: Option<f32>, pub repetition_penalty: Option<f32>,
#[serde(default)] #[serde(default)]
#[schema(
exclusive_minimum = -2.0,
nullable = true,
default = "null",
example = 0.1
)]
pub frequency_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>, pub top_k: Option<i32>,
#[serde(default)] #[serde(default)]
@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
best_of: None, best_of: None,
temperature: None, temperature: None,
repetition_penalty: None, repetition_penalty: None,
frequency_penalty: None,
top_k: None, top_k: None,
top_p: None, top_p: None,
typical_p: None, typical_p: None,
@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
pub(crate) struct ChatCompletionComplete { pub(crate) struct ChatCompletionComplete {
pub index: u32, pub index: u32,
pub message: Message, pub message: Message,
pub logprobs: Option<Vec<f32>>, pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: String, pub finish_reason: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprobs {
content: Vec<ChatCompletionLogprob>,
}
impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
fn from(value: (Token, Vec<Token>)) -> Self {
let (token, top_tokens) = value;
Self {
content: vec![ChatCompletionLogprob {
token: token.text,
logprob: token.logprob,
top_logprobs: top_tokens
.into_iter()
.map(|t| ChatCompletionTopLogprob {
token: t.text,
logprob: t.logprob,
})
.collect(),
}],
}
}
}
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
let (tokens, top_tokens) = value;
Self {
content: tokens
.into_iter()
.zip(top_tokens)
.map(|(t, top_t)| ChatCompletionLogprob {
token: t.text,
logprob: t.logprob,
top_logprobs: top_t
.into_iter()
.map(|t| ChatCompletionTopLogprob {
token: t.text,
logprob: t.logprob,
})
.collect(),
})
.collect(),
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionLogprob {
token: String,
logprob: f32,
top_logprobs: Vec<ChatCompletionTopLogprob>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionTopLogprob {
token: String,
logprob: f32,
}
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
pub(crate) struct Usage { pub(crate) struct Usage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
@ -238,7 +308,7 @@ impl ChatCompletion {
content: output, content: output,
}, },
logprobs: return_logprobs logprobs: return_logprobs
.then(|| details.tokens.iter().map(|t| t.logprob).collect()), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.to_string(),
}], }],
usage: Usage { usage: Usage {
@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
pub(crate) struct ChatCompletionChoice { pub(crate) struct ChatCompletionChoice {
pub index: u32, pub index: u32,
pub delta: ChatCompletionDelta, pub delta: ChatCompletionDelta,
pub logprobs: Option<f32>, pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
@ -285,7 +355,7 @@ impl ChatCompletionChunk {
delta: String, delta: String,
created: u64, created: u64,
index: u32, index: u32,
logprobs: Option<f32>, logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Self { ) -> Self {
Self { Self {
@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
/// UNUSED /// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, /* NOTE: UNUSED */ pub model: String,
/* NOTE: UNUSED */
/// A list of messages comprising the conversation so far. /// A list of messages comprising the conversation so far.
#[serde(default = "default_request_messages")] #[serde(default = "default_request_messages")]
pub messages: Vec<Message>, pub messages: Vec<Message>,
@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
#[schema(example = "false")] #[schema(example = "false")]
pub logprobs: Option<bool>, pub logprobs: Option<bool>,
/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used. /// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)] #[serde(default)]
@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "2")] #[schema(nullable = true, example = "2")]
pub n: Option<u32>, pub n: Option<u32>,
/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics /// increasing the model's likelihood to talk about new topics
#[serde(default)] #[serde(default)]
@ -447,7 +515,7 @@ pub struct PrefillToken {
logprob: f32, logprob: f32,
} }
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token { pub struct Token {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, id: u32,

View File

@ -355,6 +355,7 @@ mod tests {
do_sample: false, do_sample: false,
seed: 0, seed: 0,
repetition_penalty: 0.0, repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false, watermark: false,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {

View File

@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -570,8 +571,8 @@ async fn chat_completions(
let stream = req.stream; let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100)); let max_new_tokens = req.max_tokens.or(Some(100));
let repetition_penalty = req let repetition_penalty = req
.frequency_penalty .presence_penalty
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0) // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
.map(|x| x + 2.0); .map(|x| x + 2.0);
let logprobs = req.logprobs.unwrap_or(false); let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed; let seed = req.seed;
@ -599,6 +600,7 @@ async fn chat_completions(
best_of: None, best_of: None,
temperature: req.temperature, temperature: req.temperature,
repetition_penalty, repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
@ -630,6 +632,10 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});
event event
.json_data(ChatCompletionChunk::new( .json_data(ChatCompletionChunk::new(
model_id.clone(), model_id.clone(),
@ -637,7 +643,7 @@ async fn chat_completions(
stream_token.token.text, stream_token.token.text,
current_time, current_time,
stream_token.index, stream_token.index,
logprobs.then_some(stream_token.token.logprob), logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()), stream_token.details.map(|d| d.finish_reason.to_string()),
)) ))
.map_or_else( .map_or_else(

View File

@ -170,6 +170,7 @@ impl Validation {
best_of, best_of,
temperature, temperature,
repetition_penalty, repetition_penalty,
frequency_penalty,
top_k, top_k,
top_p, top_p,
typical_p, typical_p,
@ -206,6 +207,11 @@ impl Validation {
return Err(ValidationError::RepetitionPenalty); return Err(ValidationError::RepetitionPenalty);
} }
let frequency_penalty = frequency_penalty.unwrap_or(0.0);
if !(-2.0..=2.0).contains(&frequency_penalty) {
return Err(ValidationError::FrequencyPenalty);
}
// Different because the proto default value is not a valid value // Different because the proto default value is not a valid value
// for the user // for the user
let top_p = top_p let top_p = top_p
@ -289,6 +295,7 @@ impl Validation {
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
repetition_penalty, repetition_penalty,
frequency_penalty,
top_k, top_k,
top_p, top_p,
typical_p, typical_p,
@ -420,6 +427,8 @@ pub enum ValidationError {
Temperature, Temperature,
#[error("`repetition_penalty` must be strictly positive")] #[error("`repetition_penalty` must be strictly positive")]
RepetitionPenalty, RepetitionPenalty,
#[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
FrequencyPenalty,
#[error("`top_p` must be > 0.0 and < 1.0")] #[error("`top_p` must be > 0.0 and < 1.0")]
TopP, TopP,
#[error("`top_k` must be strictly positive")] #[error("`top_k` must be strictly positive")]

View File

@ -86,6 +86,7 @@ except ImportError as e:
if MAMBA_AVAILABLE: if MAMBA_AVAILABLE:
__all__.append(Mamba) __all__.append(Mamba)
def get_model( def get_model(
model_id: str, model_id: str,
revision: Optional[str], revision: Optional[str],

View File

@ -696,14 +696,17 @@ class CausalLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
top_token_ids, top_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
special_toptokens = [ special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids token_id in self.all_special_ids
for token_id in top_token_ids
] ]
top_tokens = Tokens( top_tokens = Tokens(
top_token_ids, top_token_ids,

View File

@ -19,6 +19,7 @@ from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math import math
class MambaConfig(PretrainedConfig): class MambaConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
class MambaBlock(nn.Module): class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
@ -60,8 +62,12 @@ class MambaBlock(nn.Module):
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False) self.dt_proj_no_bias = FastLinear.load(
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False) config, f"{prefix}.dt_proj", weights, bias=False
)
self.out_proj = FastLinear.load(
config, f"{prefix}.out_proj", weights, bias=False
)
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
self.D = weights.get_tensor(f"{prefix}.D") self.D = weights.get_tensor(f"{prefix}.D")
@ -80,12 +86,14 @@ class MambaBlock(nn.Module):
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
return out, conv_state, ssm_state return out, conv_state, ssm_state
projected_states = self.in_proj(hidden_states).transpose(1,2) projected_states = self.in_proj(hidden_states).transpose(1, 2)
x, z = projected_states.chunk(2, dim=1) x, z = projected_states.chunk(2, dim=1)
conv_state = F.pad(x, (self.d_conv - seqlen, 0)) conv_state = F.pad(x, (self.d_conv - seqlen, 0))
x = causal_conv1d_fn( x = causal_conv1d_fn(
x=x, x=x,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), weight=self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
) )
@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
# We want dt to have d as the slowest moving dimension # We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt, B, C = torch.split(
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
dt = self.dt_proj.weight @ dt.t() dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@ -118,28 +128,39 @@ class MambaBlock(nn.Module):
def step(self, hidden_states, conv_state, ssm_state): def step(self, hidden_states, conv_state, ssm_state):
_xz = self.in_proj(hidden_states) _xz = self.in_proj(hidden_states)
_x, _z = _xz.chunk(2, dim=-1) # (B D) _x, _z = _xz.chunk(2, dim=-1) # (B D)
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
conv_out = causal_conv1d_fn( conv_out = causal_conv1d_fn(
x=conv_state_new, x=conv_state_new,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), weight=self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation activation=self.activation,
) )
conv_state = conv_state_new[:, :, 1:] conv_state = conv_state_new[:, :, 1:]
bsz, seqlen, dim = hidden_states.shape bsz, seqlen, dim = hidden_states.shape
output_tensor = torch.zeros( output_tensor = torch.zeros(
(bsz, seqlen, dim), (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
device=hidden_states.device,
dtype=hidden_states.dtype
) )
for i in range(0, bsz): for i in range(0, bsz):
x = conv_out[i:i+1,:,-1] x = conv_out[i : i + 1, :, -1]
z = _z[i:i+1, -1, :] z = _z[i : i + 1, -1, :]
x_db = self.x_proj(x) x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt, B, C = torch.split(
x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
dt = F.linear(dt, self.dt_proj.weight) dt = F.linear(dt, self.dt_proj.weight)
y = selective_state_update( y = selective_state_update(
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True ssm_state[i : i + 1, :, :],
x,
dt,
self.negA,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
) )
out = self.out_proj(y) out = self.out_proj(y)
output_tensor[i] = out output_tensor[i] = out
@ -147,12 +168,15 @@ class MambaBlock(nn.Module):
return output_tensor, conv_state, ssm_state return output_tensor, conv_state, ssm_state
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, config, weights):
super().__init__() super().__init__()
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) self.mamba_block = MambaBlock(
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) prefix=f"{layer_id}.mixer", config=config, weights=weights
)
self.layer_norm = FastRMSNorm.load(
prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
)
def forward( def forward(
self, self,
@ -163,28 +187,47 @@ class ResidualBlock(nn.Module):
residual = (hidden_states + residual) if residual is not None else hidden_states residual = (hidden_states + residual) if residual is not None else hidden_states
shape = residual.shape shape = residual.shape
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params) hidden_states, conv_state, last_ssm_state = self.mamba_block(
hidden_states.view(*shape), inference_params
)
return hidden_states, residual, conv_state, last_ssm_state return hidden_states, residual, conv_state, last_ssm_state
class MambaModel(nn.Module): class MambaModel(nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
prefix = "backbone" prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] [
ResidualBlock(f"{prefix}.layers.{i}", config, weights)
for i in range(config.n_layer)
]
)
self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
)
self.lm_head = FastLinear.load(
config, f"{prefix}.embedding", weights, bias=False
) )
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
self.config = config self.config = config
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: def forward(
self, input_ids: torch.Tensor, inference_params=None, residual=None
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for block in self.blocks: for block in self.blocks:
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params) hidden_states, residual, conv_state, ssm_state = block(
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) hidden_states, residual, inference_params
)
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
conv_state,
ssm_state,
)
hidden_states = hidden_states + residual if residual is not None else hidden_states hidden_states = (
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
hidden_states = hidden_states.view(residual.shape) hidden_states = hidden_states.view(residual.shape)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -842,7 +842,6 @@ class FlashCausalLM(Model):
else: else:
next_token_logits = out next_token_logits = out
speculate = get_speculate() speculate = get_speculate()
( (
next_input_ids, next_input_ids,
@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
top_token_ids, top_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
special_toptokens = [ special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids token_id in self.all_special_ids
for token_id in top_token_ids
] ]
top_tokens = Tokens( top_tokens = Tokens(
top_token_ids, top_token_ids,

View File

@ -26,6 +26,7 @@ from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from mamba_ssm.utils.generation import InferenceParams from mamba_ssm.utils.generation import InferenceParams
@dataclass @dataclass
class MambaBatch(Batch): class MambaBatch(Batch):
batch_id: int batch_id: int
@ -221,7 +222,10 @@ class MambaBatch(Batch):
# TODO # TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary. # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
key_value_memory_dict = {} key_value_memory_dict = {}
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items(): for i, (
conv_state,
ssm_state,
) in self.inference_params.key_value_memory_dict.items():
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices]) key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
self.inference_params.key_value_memory_dict = key_value_memory_dict self.inference_params.key_value_memory_dict = key_value_memory_dict
@ -305,8 +309,9 @@ class MambaBatch(Batch):
start_index = end_index start_index = end_index
(_, d_model, d_conv) = (
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape batches[0].inference_params.key_value_memory_dict[0][0].shape
)
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
n_blocks = len(batches[0].inference_params.key_value_memory_dict) n_blocks = len(batches[0].inference_params.key_value_memory_dict)
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
@ -344,9 +349,15 @@ class MambaBatch(Batch):
for i in range(n_blocks): for i in range(n_blocks):
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i] conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
batch_size = batch.inference_params.max_batch_size batch_size = batch.inference_params.max_batch_size
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state inference_params.key_value_memory_dict[i][0][
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state current_batch : current_batch + batch_size
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample ] = conv_state
inference_params.key_value_memory_dict[i][1][
current_batch : current_batch + batch_size
] = ssm_state
inference_params.lengths_per_sample[
current_batch : current_batch + batch_size
] = batch.inference_params.lengths_per_sample
current_batch += batch_size current_batch += batch_size
return cls( return cls(
@ -366,12 +377,13 @@ class MambaBatch(Batch):
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens, max_tokens=max_tokens,
inference_params=inference_params inference_params=inference_params,
) )
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
class Mamba(Model): class Mamba(Model):
def __init__( def __init__(
self, self,
@ -441,7 +453,9 @@ class Mamba(Model):
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns() start = time.time_ns()
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids input_ids = (
batch.input_ids
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size = input_ids.shape[0] batch_size = input_ids.shape[0]
max_seqlen = input_ids.shape[1] max_seqlen = input_ids.shape[1]
@ -450,7 +464,10 @@ class Mamba(Model):
# Inference params # Inference params
seqlen_og = 0 seqlen_og = 0
inf_cache = {} inf_cache = {}
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen lengths_per_sample = (
torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
* max_seqlen
)
if batch.inference_params is None: if batch.inference_params is None:
inference_params = InferenceParams( inference_params = InferenceParams(
@ -478,11 +495,16 @@ class Mamba(Model):
device=block.dt_proj.weight.device, device=block.dt_proj.weight.device,
dtype=block.dt_proj.weight.dtype, dtype=block.dt_proj.weight.dtype,
) )
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state) inference_params.key_value_memory_dict[block.layer_idx] = (
conv_state,
ssm_state,
)
batch.inference_params = inference_params batch.inference_params = inference_params
# Forward pass # Forward pass
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params) logits, past_input_ids, new_inference_params = self.model(
input_ids, batch.inference_params
)
batch.inference_params = new_inference_params batch.inference_params = new_inference_params
# Results # Results
@ -564,7 +586,8 @@ class Mamba(Model):
prefix_offset=len(all_input_ids) prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens - stopping_criteria.current_tokens
- 1, - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens, read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True, skip_special_tokens=True,
) )
# Get seed # Get seed

View File

@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs): for (top_token_ids, top_token_logprobs) in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
top_token_ids, top_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
special_toptokens = [ special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids token_id in self.all_special_ids
for token_id in top_token_ids
] ]
top_tokens = Tokens( top_tokens = Tokens(
top_token_ids, top_token_ids,

View File

@ -95,5 +95,7 @@ class Generation:
generated_text=self.generated_text.to_pb() generated_text=self.generated_text.to_pb()
if self.generated_text is not None if self.generated_text is not None
else None, else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None, top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None,
) )

View File

@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return None return None
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: float):
self.penalty = penalty
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty, score / self.penalty
)
return scores.scatter_add_(1, input_ids, score)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
frequency_penalty (`List[float]`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty
self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
return scores.scatter_add_(1, input_ids, score)
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]
if any([x != 0.0 for x in self.penalty]):
self.penalty_tensor = self.penalty_tensor[indices]
return self
return None
class HeterogeneousTemperatureLogitsWarper: class HeterogeneousTemperatureLogitsWarper:
r""" r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution). [`LogitsWarper`] for temperature (exponential scaling output probability distribution).

View File

@ -1,12 +1,14 @@
import re import re
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import ( from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
HeterogeneousProcessorWrapper, HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper, HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper, HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper, HeterogeneousTopPLogitsWarper,
@ -23,6 +25,7 @@ class NextTokenChooser:
watermark=False, watermark=False,
temperature=1.0, temperature=1.0,
repetition_penalty=1.0, repetition_penalty=1.0,
frequency_penalty=0.0,
top_k=None, top_k=None,
top_p=None, top_p=None,
typical_p=None, typical_p=None,
@ -35,7 +38,12 @@ class NextTokenChooser:
) )
self.repetition_processor = ( self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty if repetition_penalty and repetition_penalty != 1.0
else None
)
self.frequency_processor = (
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
if frequency_penalty and frequency_penalty != 0.0
else None else None
) )
@ -60,6 +68,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores) scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None: if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores) scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores)
if self.static_warper is None: if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1) next_logprob = torch.log_softmax(scores, -1)
@ -80,6 +90,7 @@ class NextTokenChooser:
watermark=pb.watermark, watermark=pb.watermark,
temperature=pb.temperature, temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty, repetition_penalty=pb.repetition_penalty,
frequency_penalty=pb.frequency_penalty,
top_k=pb.top_k, top_k=pb.top_k,
top_p=pb.top_p, top_p=pb.top_p,
typical_p=pb.typical_p, typical_p=pb.typical_p,
@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
watermark: List[bool], watermark: List[bool],
temperature: List[float], temperature: List[float],
repetition_penalty: List[float], repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int], top_k: List[int],
top_p: List[float], top_p: List[float],
typical_p: List[float], typical_p: List[float],
@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
else None else None
) )
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
if any([x != 1.0 for x in temperature]): if any([x != 1.0 for x in temperature]):
do_sample = [ do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample) sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.watermark_processor(input_ids, _scores) _scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None: if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores) _scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate > 0:
if speculative_scores is not None: if speculative_scores is not None:
# Medusa provided some scores # Medusa provided some scores
@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
if self.repetition_processor is not None: if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices) self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
filtered_warpers = [] filtered_warpers = []
for warper in self.warpers: for warper in self.warpers:
filtered_warper = warper.filter(indices) filtered_warper = warper.filter(indices)
@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb], temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb], top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb], top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb],
@ -438,7 +463,10 @@ class HeterogeneousSampling:
def batch_top_tokens( def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations. """Find the top n most likely tokens for a batch of generations.
@ -450,12 +478,15 @@ def batch_top_tokens(
if max_top_n == 0: if max_top_n == 0:
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens) return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0] batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size) top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size # Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)] top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset # Sorted topk is faster than torch.sort() since we only need a small subset
@ -484,10 +515,10 @@ def batch_top_tokens(
for i, n_accepted_ids in enumerate(accepted_ids_list): for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i start = speculate_size * i
stop = speculate_size * (i + 1) stop = speculate_size * (i + 1)
_top_indices = top_indices[start: stop] _top_indices = top_indices[start:stop]
_top_values = top_values[start: stop] _top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start: stop] _top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start: stop] _top_n_tokens = top_n_tokens[start:stop]
_top_indices = _top_indices[:n_accepted_ids] _top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids] _top_values = _top_values[:n_accepted_ids]
@ -497,7 +528,9 @@ def batch_top_tokens(
row_top_token_ids = [] row_top_token_ids = []
row_top_token_logprobs = [] row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens): for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
indices = idxs[:n] if req_n > 0 else [] indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else [] values = vals[:n] if req_n > 0 else []