feat(server): add frequency penalty (#1541)
This commit is contained in:
parent
39af000cb9
commit
09b7c26bbd
|
@ -2787,7 +2787,7 @@ dependencies = [
|
|||
"tabled",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.14.1",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
@ -2850,7 +2850,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.15.1",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower-http",
|
||||
|
@ -2972,6 +2972,40 @@ dependencies = [
|
|||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"clap",
|
||||
"derive_builder",
|
||||
"esaxx-rs",
|
||||
"getrandom",
|
||||
"hf-hub",
|
||||
"indicatif",
|
||||
"itertools 0.11.0",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
"monostate",
|
||||
"onig",
|
||||
"paste",
|
||||
"rand",
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.7.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
"thiserror",
|
||||
"unicode-normalization-alignments",
|
||||
"unicode-segmentation",
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.35.1"
|
||||
|
|
|
@ -30,6 +30,7 @@ pub async fn run(
|
|||
top_p: Option<f32>,
|
||||
typical_p: Option<f32>,
|
||||
repetition_penalty: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
watermark: bool,
|
||||
do_sample: bool,
|
||||
client: ShardedClient,
|
||||
|
@ -42,6 +43,7 @@ pub async fn run(
|
|||
do_sample,
|
||||
seed: 0,
|
||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||
frequency_penalty: frequency_penalty.unwrap_or(0.0),
|
||||
watermark,
|
||||
};
|
||||
|
||||
|
@ -140,6 +142,7 @@ pub async fn run(
|
|||
top_p,
|
||||
typical_p,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
watermark,
|
||||
do_sample,
|
||||
);
|
||||
|
|
|
@ -84,6 +84,11 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
repetition_penalty: Option<f32>,
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
frequency_penalty: Option<f32>,
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
|
@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
top_p,
|
||||
typical_p,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
watermark,
|
||||
do_sample,
|
||||
master_shard_uds_path,
|
||||
|
@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
top_p,
|
||||
typical_p,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
watermark,
|
||||
do_sample,
|
||||
sharded_client,
|
||||
|
|
|
@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
|
|||
top_p: Option<f32>,
|
||||
typical_p: Option<f32>,
|
||||
repetition_penalty: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
watermark: bool,
|
||||
do_sample: bool,
|
||||
) -> Table {
|
||||
|
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
|
|||
builder.push_record(["Top P", &format!("{top_p:?}")]);
|
||||
builder.push_record(["Typical P", &format!("{typical_p:?}")]);
|
||||
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
|
||||
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
|
||||
builder.push_record(["Watermark", &watermark.to_string()]);
|
||||
builder.push_record(["Do Sample", &do_sample.to_string()]);
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
|
|||
assert response.generated_text == "\n\nDeep learning is a new type of machine"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||
|
@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
|||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
|
||||
assert (
|
||||
response.generated_text
|
||||
== "blue, red, yellow, \nand orange (in the order they appear in"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
|
||||
responses = await generate_load(
|
||||
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
|
|
@ -66,6 +66,8 @@ message NextTokenChooserParameters {
|
|||
uint64 seed = 6;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 7;
|
||||
/// frequency penalty
|
||||
float frequency_penalty = 9;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 8;
|
||||
}
|
||||
|
|
|
@ -125,6 +125,7 @@ impl Client {
|
|||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
|
|
|
@ -43,6 +43,7 @@ impl Health {
|
|||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
|
|
|
@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
|
|||
)]
|
||||
pub repetition_penalty: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
exclusive_minimum = -2.0,
|
||||
nullable = true,
|
||||
default = "null",
|
||||
example = 0.1
|
||||
)]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
|
||||
pub top_k: Option<i32>,
|
||||
#[serde(default)]
|
||||
|
@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
best_of: None,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
frequency_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
|
@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
|
|||
pub(crate) struct ChatCompletionComplete {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub logprobs: Option<Vec<f32>>,
|
||||
pub logprobs: Option<ChatCompletionLogprobs>,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionLogprobs {
|
||||
content: Vec<ChatCompletionLogprob>,
|
||||
}
|
||||
|
||||
impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
|
||||
fn from(value: (Token, Vec<Token>)) -> Self {
|
||||
let (token, top_tokens) = value;
|
||||
|
||||
Self {
|
||||
content: vec![ChatCompletionLogprob {
|
||||
token: token.text,
|
||||
logprob: token.logprob,
|
||||
top_logprobs: top_tokens
|
||||
.into_iter()
|
||||
.map(|t| ChatCompletionTopLogprob {
|
||||
token: t.text,
|
||||
logprob: t.logprob,
|
||||
})
|
||||
.collect(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
|
||||
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
|
||||
let (tokens, top_tokens) = value;
|
||||
Self {
|
||||
content: tokens
|
||||
.into_iter()
|
||||
.zip(top_tokens)
|
||||
.map(|(t, top_t)| ChatCompletionLogprob {
|
||||
token: t.text,
|
||||
logprob: t.logprob,
|
||||
top_logprobs: top_t
|
||||
.into_iter()
|
||||
.map(|t| ChatCompletionTopLogprob {
|
||||
token: t.text,
|
||||
logprob: t.logprob,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionLogprob {
|
||||
token: String,
|
||||
logprob: f32,
|
||||
top_logprobs: Vec<ChatCompletionTopLogprob>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct ChatCompletionTopLogprob {
|
||||
token: String,
|
||||
logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub(crate) struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
|
@ -238,7 +308,7 @@ impl ChatCompletion {
|
|||
content: output,
|
||||
},
|
||||
logprobs: return_logprobs
|
||||
.then(|| details.tokens.iter().map(|t| t.logprob).collect()),
|
||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
|
@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
|
|||
pub(crate) struct ChatCompletionChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatCompletionDelta,
|
||||
pub logprobs: Option<f32>,
|
||||
pub logprobs: Option<ChatCompletionLogprobs>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
|
@ -285,7 +355,7 @@ impl ChatCompletionChunk {
|
|||
delta: String,
|
||||
created: u64,
|
||||
index: u32,
|
||||
logprobs: Option<f32>,
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
finish_reason: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
|
|||
/// UNUSED
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
pub model: String, /* NOTE: UNUSED */
|
||||
|
||||
pub model: String,
|
||||
/* NOTE: UNUSED */
|
||||
/// A list of messages comprising the conversation so far.
|
||||
#[serde(default = "default_request_messages")]
|
||||
pub messages: Vec<Message>,
|
||||
|
@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
|
|||
#[schema(example = "false")]
|
||||
pub logprobs: Option<bool>,
|
||||
|
||||
/// UNUSED
|
||||
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
||||
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
||||
#[serde(default)]
|
||||
|
@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
|
|||
#[schema(nullable = true, example = "2")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
||||
/// increasing the model's likelihood to talk about new topics
|
||||
#[serde(default)]
|
||||
|
@ -447,7 +515,7 @@ pub struct PrefillToken {
|
|||
logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
pub struct Token {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
|
|
|
@ -355,6 +355,7 @@ mod tests {
|
|||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
|
|
|
@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
|||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||
ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
|
||||
GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
|
||||
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
|
||||
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
|
@ -570,8 +571,8 @@ async fn chat_completions(
|
|||
let stream = req.stream;
|
||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||
let repetition_penalty = req
|
||||
.frequency_penalty
|
||||
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||
.presence_penalty
|
||||
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||
.map(|x| x + 2.0);
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
|
@ -599,6 +600,7 @@ async fn chat_completions(
|
|||
best_of: None,
|
||||
temperature: req.temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
|
@ -630,6 +632,10 @@ async fn chat_completions(
|
|||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
|
||||
});
|
||||
|
||||
event
|
||||
.json_data(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
|
@ -637,7 +643,7 @@ async fn chat_completions(
|
|||
stream_token.token.text,
|
||||
current_time,
|
||||
stream_token.index,
|
||||
logprobs.then_some(stream_token.token.logprob),
|
||||
logprobs,
|
||||
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||
))
|
||||
.map_or_else(
|
||||
|
|
|
@ -170,6 +170,7 @@ impl Validation {
|
|||
best_of,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
typical_p,
|
||||
|
@ -206,6 +207,11 @@ impl Validation {
|
|||
return Err(ValidationError::RepetitionPenalty);
|
||||
}
|
||||
|
||||
let frequency_penalty = frequency_penalty.unwrap_or(0.0);
|
||||
if !(-2.0..=2.0).contains(&frequency_penalty) {
|
||||
return Err(ValidationError::FrequencyPenalty);
|
||||
}
|
||||
|
||||
// Different because the proto default value is not a valid value
|
||||
// for the user
|
||||
let top_p = top_p
|
||||
|
@ -289,6 +295,7 @@ impl Validation {
|
|||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
typical_p,
|
||||
|
@ -420,6 +427,8 @@ pub enum ValidationError {
|
|||
Temperature,
|
||||
#[error("`repetition_penalty` must be strictly positive")]
|
||||
RepetitionPenalty,
|
||||
#[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
|
||||
FrequencyPenalty,
|
||||
#[error("`top_p` must be > 0.0 and < 1.0")]
|
||||
TopP,
|
||||
#[error("`top_k` must be strictly positive")]
|
||||
|
|
|
@ -86,6 +86,7 @@ except ImportError as e:
|
|||
if MAMBA_AVAILABLE:
|
||||
__all__.append(Mamba)
|
||||
|
||||
|
||||
def get_model(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
|
|
|
@ -696,14 +696,17 @@ class CausalLM(Model):
|
|||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
for (top_token_ids, top_token_logprobs) in zip(
|
||||
top_token_ids, top_token_logprobs
|
||||
):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
token_id in self.all_special_ids
|
||||
for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
|
|
|
@ -19,6 +19,7 @@ from einops import rearrange
|
|||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
import math
|
||||
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
@ -60,8 +62,12 @@ class MambaBlock(nn.Module):
|
|||
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
|
||||
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
|
||||
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
|
||||
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
|
||||
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
|
||||
self.dt_proj_no_bias = FastLinear.load(
|
||||
config, f"{prefix}.dt_proj", weights, bias=False
|
||||
)
|
||||
self.out_proj = FastLinear.load(
|
||||
config, f"{prefix}.out_proj", weights, bias=False
|
||||
)
|
||||
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
|
||||
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
|
||||
self.D = weights.get_tensor(f"{prefix}.D")
|
||||
|
@ -85,7 +91,9 @@ class MambaBlock(nn.Module):
|
|||
conv_state = F.pad(x, (self.d_conv - seqlen, 0))
|
||||
x = causal_conv1d_fn(
|
||||
x=x,
|
||||
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||
weight=self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
|
|||
# We want dt to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
dt, B, C = torch.split(
|
||||
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
|
||||
)
|
||||
dt = self.dt_proj.weight @ dt.t()
|
||||
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||
|
@ -121,25 +131,36 @@ class MambaBlock(nn.Module):
|
|||
conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
|
||||
conv_out = causal_conv1d_fn(
|
||||
x=conv_state_new,
|
||||
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||
weight=self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation
|
||||
activation=self.activation,
|
||||
)
|
||||
conv_state = conv_state_new[:, :, 1:]
|
||||
bsz, seqlen, dim = hidden_states.shape
|
||||
output_tensor = torch.zeros(
|
||||
(bsz, seqlen, dim),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype
|
||||
(bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
for i in range(0, bsz):
|
||||
x = conv_out[i : i + 1, :, -1]
|
||||
z = _z[i : i + 1, -1, :]
|
||||
x_db = self.x_proj(x)
|
||||
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
dt, B, C = torch.split(
|
||||
x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
|
||||
)
|
||||
dt = F.linear(dt, self.dt_proj.weight)
|
||||
y = selective_state_update(
|
||||
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
||||
ssm_state[i : i + 1, :, :],
|
||||
x,
|
||||
dt,
|
||||
self.negA,
|
||||
B,
|
||||
C,
|
||||
self.D,
|
||||
z=z,
|
||||
dt_bias=self.dt_proj.bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
out = self.out_proj(y)
|
||||
output_tensor[i] = out
|
||||
|
@ -147,12 +168,15 @@ class MambaBlock(nn.Module):
|
|||
return output_tensor, conv_state, ssm_state
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
|
||||
self.mamba_block = MambaBlock(
|
||||
prefix=f"{layer_id}.mixer", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm = FastRMSNorm.load(
|
||||
prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -163,28 +187,47 @@ class ResidualBlock(nn.Module):
|
|||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
shape = residual.shape
|
||||
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
|
||||
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
|
||||
hidden_states, conv_state, last_ssm_state = self.mamba_block(
|
||||
hidden_states.view(*shape), inference_params
|
||||
)
|
||||
return hidden_states, residual, conv_state, last_ssm_state
|
||||
|
||||
|
||||
class MambaModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
prefix = "backbone"
|
||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||
self.blocks = nn.ModuleList(
|
||||
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
||||
[
|
||||
ResidualBlock(f"{prefix}.layers.{i}", config, weights)
|
||||
for i in range(config.n_layer)
|
||||
]
|
||||
)
|
||||
self.norm_f = FastRMSNorm.load(
|
||||
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.lm_head = FastLinear.load(
|
||||
config, f"{prefix}.embedding", weights, bias=False
|
||||
)
|
||||
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
||||
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
|
||||
self.config = config
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, inference_params=None, residual=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for block in self.blocks:
|
||||
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
|
||||
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
|
||||
hidden_states, residual, conv_state, ssm_state = block(
|
||||
hidden_states, residual, inference_params
|
||||
)
|
||||
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
|
||||
conv_state,
|
||||
ssm_state,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + residual if residual is not None else hidden_states
|
||||
hidden_states = (
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
|
||||
hidden_states = hidden_states.view(residual.shape)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
|
|
@ -842,7 +842,6 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
next_token_logits = out
|
||||
|
||||
|
||||
speculate = get_speculate()
|
||||
(
|
||||
next_input_ids,
|
||||
|
@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
|
|||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
for (top_token_ids, top_token_logprobs) in zip(
|
||||
top_token_ids, top_token_logprobs
|
||||
):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
token_id in self.all_special_ids
|
||||
for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
|
|
|
@ -26,6 +26,7 @@ from dataclasses import dataclass
|
|||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from mamba_ssm.utils.generation import InferenceParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaBatch(Batch):
|
||||
batch_id: int
|
||||
|
@ -221,7 +222,10 @@ class MambaBatch(Batch):
|
|||
# TODO
|
||||
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||
key_value_memory_dict = {}
|
||||
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
|
||||
for i, (
|
||||
conv_state,
|
||||
ssm_state,
|
||||
) in self.inference_params.key_value_memory_dict.items():
|
||||
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
||||
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
||||
|
||||
|
@ -305,8 +309,9 @@ class MambaBatch(Batch):
|
|||
|
||||
start_index = end_index
|
||||
|
||||
|
||||
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||
(_, d_model, d_conv) = (
|
||||
batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||
)
|
||||
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
||||
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
||||
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
||||
|
@ -344,9 +349,15 @@ class MambaBatch(Batch):
|
|||
for i in range(n_blocks):
|
||||
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
||||
batch_size = batch.inference_params.max_batch_size
|
||||
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
|
||||
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
|
||||
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
|
||||
inference_params.key_value_memory_dict[i][0][
|
||||
current_batch : current_batch + batch_size
|
||||
] = conv_state
|
||||
inference_params.key_value_memory_dict[i][1][
|
||||
current_batch : current_batch + batch_size
|
||||
] = ssm_state
|
||||
inference_params.lengths_per_sample[
|
||||
current_batch : current_batch + batch_size
|
||||
] = batch.inference_params.lengths_per_sample
|
||||
current_batch += batch_size
|
||||
|
||||
return cls(
|
||||
|
@ -366,12 +377,13 @@ class MambaBatch(Batch):
|
|||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
inference_params=inference_params
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class Mamba(Model):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -441,7 +453,9 @@ class Mamba(Model):
|
|||
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||
input_ids = (
|
||||
batch.input_ids
|
||||
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
max_seqlen = input_ids.shape[1]
|
||||
|
@ -450,7 +464,10 @@ class Mamba(Model):
|
|||
# Inference params
|
||||
seqlen_og = 0
|
||||
inf_cache = {}
|
||||
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
|
||||
lengths_per_sample = (
|
||||
torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
|
||||
* max_seqlen
|
||||
)
|
||||
|
||||
if batch.inference_params is None:
|
||||
inference_params = InferenceParams(
|
||||
|
@ -478,11 +495,16 @@ class Mamba(Model):
|
|||
device=block.dt_proj.weight.device,
|
||||
dtype=block.dt_proj.weight.dtype,
|
||||
)
|
||||
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
|
||||
inference_params.key_value_memory_dict[block.layer_idx] = (
|
||||
conv_state,
|
||||
ssm_state,
|
||||
)
|
||||
batch.inference_params = inference_params
|
||||
|
||||
# Forward pass
|
||||
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
|
||||
logits, past_input_ids, new_inference_params = self.model(
|
||||
input_ids, batch.inference_params
|
||||
)
|
||||
|
||||
batch.inference_params = new_inference_params
|
||||
# Results
|
||||
|
@ -564,7 +586,8 @@ class Mamba(Model):
|
|||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||
read_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
|
|
|
@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
|
|||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
for (top_token_ids, top_token_logprobs) in zip(
|
||||
top_token_ids, top_token_logprobs
|
||||
):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
token_id in self.all_special_ids
|
||||
for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
|
|
|
@ -95,5 +95,7 @@ class Generation:
|
|||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
|
||||
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
|
||||
if self.top_tokens is not None
|
||||
else None,
|
||||
)
|
||||
|
|
|
@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||
return None
|
||||
|
||||
|
||||
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
Frequency penalty as defined by OpenAI
|
||||
|
||||
Args:
|
||||
penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float):
|
||||
self.penalty = penalty
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||
score = -torch.where(
|
||||
score < 0, score * self.penalty, score / self.penalty
|
||||
)
|
||||
|
||||
return scores.scatter_add_(1, input_ids, score)
|
||||
|
||||
|
||||
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
Frequency penalty as defined by OpenAI
|
||||
|
||||
Args:
|
||||
frequency_penalty (`List[float]`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
|
||||
self.penalty = penalty
|
||||
self.penalty_tensor = torch.tensor(
|
||||
penalty, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||
score = -torch.where(
|
||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||
)
|
||||
|
||||
return scores.scatter_add_(1, input_ids, score)
|
||||
|
||||
def filter(self, indices):
|
||||
self.penalty = [self.penalty[i] for i in indices]
|
||||
if any([x != 0.0 for x in self.penalty]):
|
||||
self.penalty_tensor = self.penalty_tensor[indices]
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTemperatureLogitsWarper:
|
||||
r"""
|
||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import re
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.logits_process import (
|
||||
FrequencyPenaltyLogitsProcessor,
|
||||
HeterogeneousProcessorWrapper,
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
||||
HeterogeneousFrequencyPenaltyLogitsProcessor,
|
||||
HeterogeneousTemperatureLogitsWarper,
|
||||
HeterogeneousTopKLogitsWarper,
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
|
@ -23,6 +25,7 @@ class NextTokenChooser:
|
|||
watermark=False,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
frequency_penalty=0.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
|
@ -35,7 +38,12 @@ class NextTokenChooser:
|
|||
)
|
||||
self.repetition_processor = (
|
||||
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||
if repetition_penalty
|
||||
if repetition_penalty and repetition_penalty != 1.0
|
||||
else None
|
||||
)
|
||||
self.frequency_processor = (
|
||||
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
|
||||
if frequency_penalty and frequency_penalty != 0.0
|
||||
else None
|
||||
)
|
||||
|
||||
|
@ -60,6 +68,8 @@ class NextTokenChooser:
|
|||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
if self.frequency_processor is not None:
|
||||
scores = self.frequency_processor(input_ids, scores)
|
||||
|
||||
if self.static_warper is None:
|
||||
next_logprob = torch.log_softmax(scores, -1)
|
||||
|
@ -80,6 +90,7 @@ class NextTokenChooser:
|
|||
watermark=pb.watermark,
|
||||
temperature=pb.temperature,
|
||||
repetition_penalty=pb.repetition_penalty,
|
||||
frequency_penalty=pb.frequency_penalty,
|
||||
top_k=pb.top_k,
|
||||
top_p=pb.top_p,
|
||||
typical_p=pb.typical_p,
|
||||
|
@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
|
|||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
frequency_penalty: List[float],
|
||||
top_k: List[int],
|
||||
top_p: List[float],
|
||||
typical_p: List[float],
|
||||
|
@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
|
|||
else None
|
||||
)
|
||||
|
||||
self.frequency_processor = (
|
||||
HeterogeneousFrequencyPenaltyLogitsProcessor(
|
||||
frequency_penalty, dtype, device
|
||||
)
|
||||
if any([x != 0.0 for x in frequency_penalty])
|
||||
else None
|
||||
)
|
||||
|
||||
if any([x != 1.0 for x in temperature]):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
|
@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
|
|||
_scores = self.watermark_processor(input_ids, _scores)
|
||||
if self.repetition_processor is not None:
|
||||
_scores = self.repetition_processor(input_ids, _scores)
|
||||
if self.frequency_processor is not None:
|
||||
_scores = self.frequency_processor(input_ids, _scores)
|
||||
|
||||
for warper in self.warpers:
|
||||
_scores = warper(input_ids, _scores)
|
||||
|
@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
|
|||
|
||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
|
||||
|
||||
if speculate > 0:
|
||||
if speculative_scores is not None:
|
||||
# Medusa provided some scores
|
||||
|
@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
|
|||
if self.repetition_processor is not None:
|
||||
self.repetition_processor = self.repetition_processor.filter(indices)
|
||||
|
||||
if self.frequency_processor is not None:
|
||||
self.frequency_processor = self.frequency_processor.filter(indices)
|
||||
|
||||
filtered_warpers = []
|
||||
for warper in self.warpers:
|
||||
filtered_warper = warper.filter(indices)
|
||||
|
@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
|
|||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
|
||||
top_k=[pb_.top_k for pb_ in pb],
|
||||
top_p=[pb_.top_p for pb_ in pb],
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
|
@ -438,7 +463,10 @@ class HeterogeneousSampling:
|
|||
|
||||
|
||||
def batch_top_tokens(
|
||||
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
|
||||
top_n_tokens: List[int],
|
||||
top_n_tokens_tensor: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
accepted_ids: torch.Tensor,
|
||||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
"""Find the top n most likely tokens for a batch of generations.
|
||||
|
||||
|
@ -450,12 +478,15 @@ def batch_top_tokens(
|
|||
if max_top_n == 0:
|
||||
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
|
||||
|
||||
|
||||
batch_size = accepted_ids.shape[0]
|
||||
speculate_size = logprobs.shape[0] // batch_size
|
||||
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
|
||||
# Ensure top_n doesn't exceed vocab size
|
||||
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
|
||||
top_n_tokens = [
|
||||
min(tok, logprobs.size(-1))
|
||||
for tok in top_n_tokens
|
||||
for _ in range(speculate_size)
|
||||
]
|
||||
|
||||
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
|
||||
# Sorted topk is faster than torch.sort() since we only need a small subset
|
||||
|
@ -497,7 +528,9 @@ def batch_top_tokens(
|
|||
row_top_token_ids = []
|
||||
row_top_token_logprobs = []
|
||||
|
||||
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
|
||||
for idxs, vals, n, req_n in zip(
|
||||
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
|
||||
):
|
||||
indices = idxs[:n] if req_n > 0 else []
|
||||
values = vals[:n] if req_n > 0 else []
|
||||
|
||||
|
|
Loading…
Reference in New Issue