feat(server): add frequency penalty (#1541)
parent 39af000cb9
commit 09b7c26bbd
@@ -2787,7 +2787,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.14.1",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -2850,7 +2850,7 @@ dependencies = [
  "serde_json",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.15.1",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -2972,6 +2972,40 @@ dependencies = [
  "unicode_categories",
 ]
 
+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+ "aho-corasick",
+ "clap",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom",
+ "hf-hub",
+ "indicatif",
+ "itertools 0.11.0",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand",
+ "rayon",
+ "rayon-cond",
+ "regex",
+ "regex-syntax 0.7.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.35.1"
@@ -30,6 +30,7 @@ pub async fn run(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
     client: ShardedClient,
@@ -42,6 +43,7 @@ pub async fn run(
         do_sample,
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
+        frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
     };
 
@@ -140,6 +142,7 @@ pub async fn run(
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
     );
@@ -84,6 +84,11 @@ struct Args {
     #[clap(long, env)]
     repetition_penalty: Option<f32>,
 
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    frequency_penalty: Option<f32>,
+
     /// Generation parameter in case you want to specifically test/debug particular
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
@@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         master_shard_uds_path,
@@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         sharded_client,
@@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
 ) -> Table {
@@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Top P", &format!("{top_p:?}")]);
     builder.push_record(["Typical P", &format!("{typical_p:?}")]);
     builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
+    builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
     builder.push_record(["Watermark", &watermark.to_string()]);
     builder.push_record(["Do Sample", &do_sample.to_string()]);
@@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
     assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot
 
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
+    assert (
+        response.generated_text
+        == "blue, red, yellow, \nand orange (in the order they appear in"
+    )
     assert response == response_snapshot
 
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-    responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
@@ -66,6 +66,8 @@ message NextTokenChooserParameters {
     uint64 seed = 6;
     /// repetition penalty
     float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
 }
@@ -125,6 +125,7 @@ impl Client {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.2,
+                frequency_penalty: 0.1,
                 watermark: true,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
@@ -43,6 +43,7 @@ impl Health {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
@@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
     )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub frequency_penalty: Option<f32>,
+    #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
     #[serde(default)]
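For reference, a hedged sketch of how the new field could be exercised against the generate route once this schema lands. The endpoint URL and prompt are assumptions for illustration; the parameter names come from the schema above and the validation added further down.

import json
import urllib.request

# Hypothetical local server URL; adjust to your deployment.
URL = "http://localhost:3000/generate"

payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {
        "max_new_tokens": 10,
        "repetition_penalty": 1.2,
        "frequency_penalty": 0.1,  # new field; validation bounds it to -2.0..=2.0
    },
}

request = urllib.request.Request(
    URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# Uncomment against a running text-generation-inference instance:
# with urllib.request.urlopen(request) as response:
#     print(json.load(response)["generated_text"])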
@@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
         best_of: None,
         temperature: None,
         repetition_penalty: None,
+        frequency_penalty: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: String,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprobs {
+    content: Vec<ChatCompletionLogprob>,
+}
+
+impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
+    fn from(value: (Token, Vec<Token>)) -> Self {
+        let (token, top_tokens) = value;
+
+        Self {
+            content: vec![ChatCompletionLogprob {
+                token: token.text,
+                logprob: token.logprob,
+                top_logprobs: top_tokens
+                    .into_iter()
+                    .map(|t| ChatCompletionTopLogprob {
+                        token: t.text,
+                        logprob: t.logprob,
+                    })
+                    .collect(),
+            }],
+        }
+    }
+}
+
+impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
+    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
+        let (tokens, top_tokens) = value;
+        Self {
+            content: tokens
+                .into_iter()
+                .zip(top_tokens)
+                .map(|(t, top_t)| ChatCompletionLogprob {
+                    token: t.text,
+                    logprob: t.logprob,
+                    top_logprobs: top_t
+                        .into_iter()
+                        .map(|t| ChatCompletionTopLogprob {
+                            token: t.text,
+                            logprob: t.logprob,
+                        })
+                        .collect(),
+                })
+                .collect(),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprob {
+    token: String,
+    logprob: f32,
+    top_logprobs: Vec<ChatCompletionTopLogprob>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionTopLogprob {
+    token: String,
+    logprob: f32,
+}
+
 #[derive(Clone, Deserialize, Serialize)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
@@ -238,7 +308,7 @@ impl ChatCompletion {
                     content: output,
                 },
                 logprobs: return_logprobs
-                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                 finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
@@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<f32>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: Option<String>,
 }
@@ -285,7 +355,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<f32>,
+        logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
     /// UNUSED
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: String, /* NOTE: UNUSED */
+    pub model: String,
+    /* NOTE: UNUSED */
     /// A list of messages comprising the conversation so far.
     #[serde(default = "default_request_messages")]
     pub messages: Vec<Message>,
@@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
     #[schema(example = "false")]
     pub logprobs: Option<bool>,
 
-    /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
@@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,
 
-    /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
@@ -447,7 +515,7 @@ pub struct PrefillToken {
     logprob: f32,
 }
 
-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 pub struct Token {
     #[schema(example = 0)]
     id: u32,
@@ -355,6 +355,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
@@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
-    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
+    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -570,8 +571,8 @@ async fn chat_completions(
     let stream = req.stream;
     let max_new_tokens = req.max_tokens.or(Some(100));
     let repetition_penalty = req
-        .frequency_penalty
-        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .presence_penalty
+        // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
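A quick sanity check on the rescaling above, as a hedged sketch (the helper name is made up for illustration): the OpenAI-style presence penalty lives in (-2.0, 2.0), while the internal repetition_penalty must be strictly positive, so the router shifts it by +2.0 into (0.0, 4.0), exactly the `.map(|x| x + 2.0)` shown in the hunk.

def presence_to_repetition_penalty(presence_penalty):
    # Mirrors the router's `.map(|x| x + 2.0)` rescaling from (-2.0, 2.0) to (0.0, 4.0).
    return None if presence_penalty is None else presence_penalty + 2.0


assert presence_to_repetition_penalty(None) is None
assert presence_to_repetition_penalty(-2.0) == 0.0
assert presence_to_repetition_penalty(0.0) == 2.0
assert presence_to_repetition_penalty(2.0) == 4.0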
@@ -599,6 +600,7 @@ async fn chat_completions(
             best_of: None,
             temperature: req.temperature,
             repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
@@ -630,6 +632,10 @@ async fn chat_completions(
                     .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                     .as_secs();
 
+                let logprobs = logprobs.then(|| {
+                    ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
+                });
+
                 event
                     .json_data(ChatCompletionChunk::new(
                         model_id.clone(),
@@ -637,7 +643,7 @@ async fn chat_completions(
                         stream_token.token.text,
                         current_time,
                         stream_token.index,
-                        logprobs.then_some(stream_token.token.logprob),
+                        logprobs,
                         stream_token.details.map(|d| d.finish_reason.to_string()),
                     ))
                     .map_or_else(
@@ -170,6 +170,7 @@ impl Validation {
             best_of,
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -206,6 +207,11 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }
 
+        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
+        if !(-2.0..=2.0).contains(&frequency_penalty) {
+            return Err(ValidationError::FrequencyPenalty);
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
@@ -289,6 +295,7 @@ impl Validation {
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -420,6 +427,8 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
+    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
+    FrequencyPenalty,
     #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
@@ -70,7 +70,7 @@ def test_batch_top_tokens():
 
     # Now let's make second member of the batch be speculated
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
     accepted_ids[1] = 2
     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
         top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
@@ -86,6 +86,7 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -696,14 +696,17 @@ class CausalLM(Model):
 
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
@@ -19,6 +19,7 @@ from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 
+
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )
 
+
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -60,10 +62,14 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
-        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
-        self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
+        self.dt_proj_no_bias = FastLinear.load(
+            config, f"{prefix}.dt_proj", weights, bias=False
+        )
+        self.out_proj = FastLinear.load(
+            config, f"{prefix}.out_proj", weights, bias=False
+        )
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
         self.D = weights.get_tensor(f"{prefix}.D")
         self.activation = "silu"
         self.dt_rank = config.dt_rank
@@ -80,12 +86,14 @@ class MambaBlock(nn.Module):
             out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
             return out, conv_state, ssm_state
 
-        projected_states = self.in_proj(hidden_states).transpose(1,2)
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
         x, z = projected_states.chunk(2, dim=1)
         conv_state = F.pad(x, (self.d_conv - seqlen, 0))
         x = causal_conv1d_fn(
             x=x,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
             bias=self.conv1d.bias,
             activation=self.activation,
         )
@@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
         # We want dt to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
-        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt, B, C = torch.split(
+            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
+        )
         dt = self.dt_proj.weight @ dt.t()
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@@ -118,28 +128,39 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
-        conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
+        conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
         conv_out = causal_conv1d_fn(
             x=conv_state_new,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
-            bias=self.conv1d.bias,
-            activation=self.activation
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
+            bias=self.conv1d.bias,
+            activation=self.activation,
         )
         conv_state = conv_state_new[:, :, 1:]
         bsz, seqlen, dim = hidden_states.shape
         output_tensor = torch.zeros(
-            (bsz, seqlen, dim),
-            device=hidden_states.device,
-            dtype=hidden_states.dtype
+            (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
         )
         for i in range(0, bsz):
-            x = conv_out[i:i+1,:,-1]
-            z = _z[i:i+1, -1, :]
+            x = conv_out[i : i + 1, :, -1]
+            z = _z[i : i + 1, -1, :]
             x_db = self.x_proj(x)
-            dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+            dt, B, C = torch.split(
+                x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
             dt = F.linear(dt, self.dt_proj.weight)
             y = selective_state_update(
-                ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+                ssm_state[i : i + 1, :, :],
+                x,
+                dt,
+                self.negA,
+                B,
+                C,
+                self.D,
+                z=z,
+                dt_bias=self.dt_proj.bias,
+                dt_softplus=True,
             )
             out = self.out_proj(y)
             output_tensor[i] = out
@@ -147,48 +168,70 @@ class MambaBlock(nn.Module):
         return output_tensor, conv_state, ssm_state
 
 
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = FastRMSNorm.load(
+            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
     ):
         residual = (hidden_states + residual) if residual is not None else hidden_states
         shape = residual.shape
         hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
-        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(
+            hidden_states.view(*shape), inference_params
+        )
         return hidden_states, residual, conv_state, last_ssm_state
 
 
 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         prefix = "backbone"
         self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
+            [
+                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
+                for i in range(config.n_layer)
+            ]
+        )
+        self.norm_f = FastRMSNorm.load(
+            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+        )
+        self.lm_head = FastLinear.load(
+            config, f"{prefix}.embedding", weights, bias=False
         )
-        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
-        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config
 
-    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    def forward(
+        self, input_ids: torch.Tensor, inference_params=None, residual=None
+    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
         for block in self.blocks:
-            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
-            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(
+                hidden_states, residual, inference_params
+            )
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
+                conv_state,
+                ssm_state,
+            )
 
-        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states = (
+            hidden_states + residual if residual is not None else hidden_states
+        )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
         logits = self.lm_head(hidden_states)
 
         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
         return logits, input_ids, inference_params
@@ -842,7 +842,6 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-
         speculate = get_speculate()
         (
             next_input_ids,
@@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
 
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                    ]
                     top_tokens = Tokens(
                         top_token_ids,
@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from mamba_ssm.utils.generation import InferenceParams
 
+
 @dataclass
 class MambaBatch(Batch):
     batch_id: int
@@ -69,7 +70,7 @@ class MambaBatch(Batch):
             size=len(self),
             max_tokens=self.max_tokens,
         )
 
     @classmethod
     def from_pb(
         cls,
@@ -196,7 +197,7 @@ class MambaBatch(Batch):
             new_padding_right_offset = max(
                 new_padding_right_offset, remaining_decode_tokens
             )
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
|
||||||
self.padding_right_offset = new_padding_right_offset
|
self.padding_right_offset = new_padding_right_offset
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||||
key_value_memory_dict = {}
|
key_value_memory_dict = {}
|
||||||
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
|
for i, (
|
||||||
|
conv_state,
|
||||||
|
ssm_state,
|
||||||
|
) in self.inference_params.key_value_memory_dict.items():
|
||||||
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
||||||
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
||||||
|
|
||||||
|
@ -305,8 +309,9 @@ class MambaBatch(Batch):
|
||||||
|
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
|
||||||
|
(_, d_model, d_conv) = (
|
||||||
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
|
batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||||
|
)
|
||||||
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
||||||
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
||||||
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
||||||
|
@ -344,9 +349,15 @@ class MambaBatch(Batch):
|
||||||
for i in range(n_blocks):
|
for i in range(n_blocks):
|
||||||
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
||||||
batch_size = batch.inference_params.max_batch_size
|
batch_size = batch.inference_params.max_batch_size
|
||||||
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
|
inference_params.key_value_memory_dict[i][0][
|
||||||
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
|
current_batch : current_batch + batch_size
|
||||||
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
|
] = conv_state
|
||||||
|
inference_params.key_value_memory_dict[i][1][
|
||||||
|
current_batch : current_batch + batch_size
|
||||||
|
] = ssm_state
|
||||||
|
inference_params.lengths_per_sample[
|
||||||
|
current_batch : current_batch + batch_size
|
||||||
|
] = batch.inference_params.lengths_per_sample
|
||||||
current_batch += batch_size
|
current_batch += batch_size
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
@ -366,12 +377,13 @@ class MambaBatch(Batch):
|
||||||
padding_right_offset=padding_right_offset,
|
padding_right_offset=padding_right_offset,
|
||||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
inference_params=inference_params
|
inference_params=inference_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class Mamba(Model):
|
class Mamba(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -428,7 +440,7 @@ class Mamba(Model):
|
||||||
def warmup(self, batch) -> Optional[int]:
|
def warmup(self, batch) -> Optional[int]:
|
||||||
# TODO: implement warmup for Mamba if needed
|
# TODO: implement warmup for Mamba if needed
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
@ -441,7 +453,9 @@ class Mamba(Model):
|
||||||
|
|
||||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||||
start = time.time_ns()
|
start = time.time_ns()
|
||||||
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
input_ids = (
|
||||||
|
batch.input_ids
|
||||||
|
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||||
|
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
max_seqlen = input_ids.shape[1]
|
max_seqlen = input_ids.shape[1]
|
||||||
|
@ -450,8 +464,11 @@ class Mamba(Model):
|
||||||
# Inference params
|
# Inference params
|
||||||
seqlen_og = 0
|
seqlen_og = 0
|
||||||
inf_cache = {}
|
inf_cache = {}
|
||||||
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
|
lengths_per_sample = (
|
||||||
|
torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
|
||||||
|
* max_seqlen
|
||||||
|
)
|
||||||
|
|
||||||
if batch.inference_params is None:
|
if batch.inference_params is None:
|
||||||
inference_params = InferenceParams(
|
inference_params = InferenceParams(
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
@ -478,11 +495,16 @@ class Mamba(Model):
|
||||||
device=block.dt_proj.weight.device,
|
device=block.dt_proj.weight.device,
|
||||||
dtype=block.dt_proj.weight.dtype,
|
dtype=block.dt_proj.weight.dtype,
|
||||||
)
|
)
|
||||||
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
|
inference_params.key_value_memory_dict[block.layer_idx] = (
|
||||||
|
conv_state,
|
||||||
|
ssm_state,
|
||||||
|
)
|
||||||
batch.inference_params = inference_params
|
batch.inference_params = inference_params
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
|
logits, past_input_ids, new_inference_params = self.model(
|
||||||
|
input_ids, batch.inference_params
|
||||||
|
)
|
||||||
|
|
||||||
batch.inference_params = new_inference_params
|
batch.inference_params = new_inference_params
|
||||||
# Results
|
# Results
|
||||||
|
@ -564,7 +586,8 @@ class Mamba(Model):
|
||||||
prefix_offset=len(all_input_ids)
|
prefix_offset=len(all_input_ids)
|
||||||
- stopping_criteria.current_tokens
|
- stopping_criteria.current_tokens
|
||||||
- 1,
|
- 1,
|
||||||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
read_offset=len(all_input_ids)
|
||||||
|
- stopping_criteria.current_tokens,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
# Get seed
|
# Get seed
|
||||||
|
|
|
@@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
 
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
@@ -95,5 +95,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
+            if self.top_tokens is not None
+            else None,
         )
@@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         return None
 
 
+class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    Frequency penalty as defined by OpenAI
+
+    Args:
+        penalty (`float`):
+            The parameter for frequency penalty. 0.0 means no penalty.
+    """
+
+    def __init__(self, penalty: float):
+        self.penalty = penalty
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        score = torch.gather(scores, 1, input_ids)
+        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
+        score = -torch.where(
+            score < 0, score * self.penalty, score / self.penalty
+        )
+
+        return scores.scatter_add_(1, input_ids, score)
+
+
+class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    Frequency penalty as defined by OpenAI
+
+    Args:
+        frequency_penalty (`List[float]`):
+            The parameter for frequency penalty. 0.0 means no penalty.
+    """
+
+    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+        self.penalty = penalty
+        self.penalty_tensor = torch.tensor(
+            penalty, dtype=dtype, device=device
+        ).unsqueeze(1)
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        score = torch.gather(scores, 1, input_ids)
+        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
+        score = -torch.where(
+            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+        )
+
+        return scores.scatter_add_(1, input_ids, score)
+
+    def filter(self, indices):
+        self.penalty = [self.penalty[i] for i in indices]
+        if any([x != 0.0 for x in self.penalty]):
+            self.penalty_tensor = self.penalty_tensor[indices]
+            return self
+        return None
+
+
 class HeterogeneousTemperatureLogitsWarper:
     r"""
     [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
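A minimal, self-contained sketch of driving the FrequencyPenaltyLogitsProcessor added above on a toy batch. The class body mirrors the processor in the hunk (the LogitsProcessor base class is dropped so the snippet runs standalone); the toy tensors and the 0.1 penalty value are illustrative assumptions.

import torch


class FrequencyPenaltyLogitsProcessor:
    """Frequency penalty as defined by OpenAI (same body as the processor above)."""

    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Gather the logits of tokens already present in the sequence.
        score = torch.gather(scores, 1, input_ids)
        # Rescale them (multiply when negative, divide when positive), negate, and
        # scatter-add back; a token repeated in input_ids is adjusted once per occurrence.
        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
        return scores.scatter_add_(1, input_ids, score)


# Toy example: one sequence, vocab size 5, tokens 2 and 4 already generated (token 2 twice).
scores = torch.tensor([[1.0, -1.0, 2.0, 0.5, -0.5]])
input_ids = torch.tensor([[2, 4, 2]])
print(FrequencyPenaltyLogitsProcessor(penalty=0.1)(input_ids, scores))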
@@ -1,12 +1,14 @@
 import re
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.logits_process import (
+    FrequencyPenaltyLogitsProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
+    HeterogeneousFrequencyPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
@@ -23,6 +25,7 @@ class NextTokenChooser:
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
+        frequency_penalty=0.0,
         top_k=None,
         top_p=None,
         typical_p=None,
@@ -35,7 +38,12 @@ class NextTokenChooser:
         )
         self.repetition_processor = (
             RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
-            if repetition_penalty
+            if repetition_penalty and repetition_penalty != 1.0
+            else None
+        )
+        self.frequency_processor = (
+            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
+            if frequency_penalty and frequency_penalty != 0.0
             else None
         )
 
@@ -60,6 +68,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.frequency_processor is not None:
+            scores = self.frequency_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -80,6 +90,7 @@ class NextTokenChooser:
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
+            frequency_penalty=pb.frequency_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
             typical_p=pb.typical_p,
@@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
+        frequency_penalty: List[float],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
@@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.frequency_processor = (
+            HeterogeneousFrequencyPenaltyLogitsProcessor(
+                frequency_penalty, dtype, device
+            )
+            if any([x != 0.0 for x in frequency_penalty])
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
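A hedged note on the gating above: per-request frequency penalties are only materialized into a tensor when at least one request in the batch sets a non-zero value; otherwise the processor stays None and the scoring path is untouched. A toy illustration of that check (the penalty values are assumptions):

import torch

frequency_penalty = [0.0, 0.3, 0.0]  # one value per request in the batch

if any([x != 0.0 for x in frequency_penalty]):
    # Same shape the heterogeneous processor builds: (batch_size, 1), so it
    # broadcasts against a (batch_size, vocab_size) score tensor.
    penalty_tensor = torch.tensor(frequency_penalty, dtype=torch.float32).unsqueeze(1)
    print(penalty_tensor.shape)  # torch.Size([3, 1])
else:
    penalty_tensor = None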
@@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
             _scores = self.watermark_processor(input_ids, _scores)
         if self.repetition_processor is not None:
             _scores = self.repetition_processor(input_ids, _scores)
+        if self.frequency_processor is not None:
+            _scores = self.frequency_processor(input_ids, _scores)
 
         for warper in self.warpers:
             _scores = warper(input_ids, _scores)
@@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
 
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-
         if speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
|
@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
|
||||||
if self.repetition_processor is not None:
|
if self.repetition_processor is not None:
|
||||||
self.repetition_processor = self.repetition_processor.filter(indices)
|
self.repetition_processor = self.repetition_processor.filter(indices)
|
||||||
|
|
||||||
|
if self.frequency_processor is not None:
|
||||||
|
self.frequency_processor = self.frequency_processor.filter(indices)
|
||||||
|
|
||||||
filtered_warpers = []
|
filtered_warpers = []
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
filtered_warper = warper.filter(indices)
|
filtered_warper = warper.filter(indices)
|
||||||
|
@@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
@@ -438,7 +463,10 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+    top_n_tokens: List[int],
+    top_n_tokens_tensor: torch.Tensor,
+    logprobs: torch.Tensor,
+    accepted_ids: torch.Tensor,
 ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
 
@@ -450,12 +478,15 @@ def batch_top_tokens(
     if max_top_n == 0:
         return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
 
 
     batch_size = accepted_ids.shape[0]
     speculate_size = logprobs.shape[0] // batch_size
     top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
+    top_n_tokens = [
+        min(tok, logprobs.size(-1))
+        for tok in top_n_tokens
+        for _ in range(speculate_size)
+    ]
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
@@ -484,10 +515,10 @@ def batch_top_tokens(
     for i, n_accepted_ids in enumerate(accepted_ids_list):
         start = speculate_size * i
         stop = speculate_size * (i + 1)
-        _top_indices = top_indices[start: stop]
-        _top_values = top_values[start: stop]
-        _top_n_ishes = top_n_ishes[start: stop]
-        _top_n_tokens = top_n_tokens[start: stop]
+        _top_indices = top_indices[start:stop]
+        _top_values = top_values[start:stop]
+        _top_n_ishes = top_n_ishes[start:stop]
+        _top_n_tokens = top_n_tokens[start:stop]
 
         _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
|
||||||
row_top_token_ids = []
|
row_top_token_ids = []
|
||||||
row_top_token_logprobs = []
|
row_top_token_logprobs = []
|
||||||
|
|
||||||
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
|
for idxs, vals, n, req_n in zip(
|
||||||
|
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
|
||||||
|
):
|
||||||
indices = idxs[:n] if req_n > 0 else []
|
indices = idxs[:n] if req_n > 0 else []
|
||||||
values = vals[:n] if req_n > 0 else []
|
values = vals[:n] if req_n > 0 else []
|
||||||
|
|
||||||
|
|