# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
This commit is contained in:
Nicolas Patry 2023-08-28 11:43:47 +02:00 committed by GitHub
parent 4486f78cf9
commit 211b54ac41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 529 additions and 34 deletions

View File

@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
@ -64,6 +65,7 @@ async fn generate_runs(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -82,6 +84,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -97,6 +100,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -130,6 +134,7 @@ async fn prefill(
batch_size: u32,
decode_length: u32,
parameters: NextTokenChooserParameters,
top_n_tokens: Option<u32>,
client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
@ -145,6 +150,7 @@ async fn prefill(
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
})
.collect();

View File

@ -22,6 +22,7 @@ pub async fn run(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -70,6 +71,7 @@ pub async fn run(
batch_size.clone(),
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
parameters,
@ -130,6 +132,7 @@ pub async fn run(
tokenizer_name,
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
temperature,

View File

@ -93,6 +93,11 @@ struct Args {
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_n_tokens: Option<u32>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
watermark,
do_sample,
master_shard_uds_path,
top_n_tokens,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
batch_size,
sequence_length,
decode_length,
top_n_tokens,
runs,
warmups,
temperature,

View File

@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Model", &tokenizer_name]);
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
builder.push_record(["Decode Length", &decode_length.to_string()]);
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
builder.push_record(["N Runs", &n_runs.to_string()]);
builder.push_record(["Warmups", &warmups.to_string()]);
builder.push_record(["Temperature", &format!("{temperature:?}")]);

View File

@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text
@ -113,6 +114,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -134,6 +137,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -164,6 +168,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@ -198,6 +203,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Iterator[StreamResponse]: stream of generated tokens
@ -219,6 +226,7 @@ class Client:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -317,6 +325,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -355,6 +364,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -376,6 +387,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -404,6 +416,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@ -438,6 +451,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@ -459,6 +474,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@ -39,6 +39,8 @@ class Parameters(BaseModel):
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -101,6 +103,12 @@ class Parameters(BaseModel):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
@validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
class Request(BaseModel):
# Prompt
@ -125,9 +133,7 @@ class Request(BaseModel):
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
return field_value
@ -179,6 +185,8 @@ class BestOfSequence(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details
@ -193,6 +201,8 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@ -219,6 +229,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]

View File

@ -159,6 +159,14 @@ struct Args {
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens is used to return information about the the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks like for classification or
/// ranking.
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
/// This is the maximum allowed input length (expressed in number of tokens)
/// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load.
@ -929,6 +937,8 @@ fn spawn_webserver(
args.max_best_of.to_string(),
"--max-stop-sequences".to_string(),
args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(),
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),

View File

@ -91,6 +91,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
@ -141,6 +143,17 @@ message PrefillTokens {
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
}
message Generation {
/// Request ID
uint64 request_id = 1;
@ -156,6 +169,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Top tokens
TopTokens top_tokens = 8;
}
message FilterBatchRequest {

View File

@ -131,6 +131,7 @@ impl Client {
ignore_eos_token: false,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
}

View File

@ -50,6 +50,7 @@ impl Health {
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,

View File

@ -138,12 +138,15 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@ -164,7 +167,10 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
@ -172,8 +178,10 @@ impl Infer {
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -191,6 +199,11 @@ impl Infer {
generated_text,
queued,
start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
@ -520,6 +533,26 @@ fn send_responses(
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
)
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
@ -527,6 +560,7 @@ fn send_responses(
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
@ -536,7 +570,7 @@ fn send_responses(
} else {
// Send message
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Token(token)),
Ok(InferStreamResponse::Intermediate { token, top_tokens }),
Duration::from_millis(10),
)?;
}
@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
Token(Token),
Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]

View File

@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
example = "null"
)]
pub seed: Option<u64>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
}
fn default_max_new_tokens() -> u32 {
@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
details: false,
decoder_input_details: false,
seed: None,
top_n_tokens: None,
}
}
@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -249,6 +255,8 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -272,6 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]

View File

@ -29,6 +29,8 @@ struct Args {
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_length: usize,
#[clap(default_value = "2048", long, env)]
@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,

View File

@ -235,6 +235,7 @@ impl State {
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -328,6 +329,7 @@ mod tests {
max_new_tokens: 1,
stop_sequences: vec![],
},
top_n_tokens: 0,
},
response_tx,
span: info_span!("entry"),

View File

@ -158,7 +158,7 @@ async fn generate(
add_prompt = Some(req.inputs.clone());
}
let details = req.parameters.details || req.parameters.decoder_input_details;
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
@ -191,6 +191,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
top_tokens: response.top_tokens,
seed: response.generated_text.seed,
}
})
@ -204,6 +205,7 @@ async fn generate(
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
})
}
false => None,
@ -385,12 +387,16 @@ async fn generate_stream(
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: None,
details: None,
};
@ -403,6 +409,7 @@ async fn generate_stream(
generated_text,
start,
queued,
top_tokens,
} => {
// Token details
let details = match details {
@ -451,6 +458,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: Some(output_text),
details
};
@ -509,6 +517,7 @@ pub async fn run(
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
@ -571,6 +580,7 @@ pub async fn run(
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);

View File

@ -15,6 +15,7 @@ pub struct Validation {
/// Validation parameters
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
/// Channel to communicate with the background tokenization task
@ -27,6 +28,7 @@ impl Validation {
tokenizer: Option<Tokenizer>,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
@ -54,6 +56,7 @@ impl Validation {
max_best_of,
sender,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
}
@ -142,6 +145,7 @@ impl Validation {
seed,
watermark,
decoder_input_details,
top_n_tokens,
..
} = request.parameters;
@ -218,6 +222,15 @@ impl Validation {
}
};
let top_n_tokens = top_n_tokens
.map(|value| {
if value > self.max_top_n_tokens {
return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
}
Ok(value)
})
.unwrap_or(Ok(0))?;
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
@ -263,6 +276,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
top_n_tokens: top_n_tokens,
})
}
@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub top_n_tokens: u32,
}
#[derive(Error, Debug)]
@ -350,6 +365,10 @@ pub enum ValidationError {
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
TopNTokens(u32, u32),
#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
TopNTokensDisabled,
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream,
#[error("`temperature` must be strictly positive")]
@ -391,14 +410,16 @@ mod tests {
let tokenizer = None;
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -418,14 +439,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -435,7 +458,7 @@ mod tests {
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
}
}
@ -445,14 +468,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -477,14 +502,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -531,4 +558,75 @@ mod tests {
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!(valid_request.parameters.top_p, 1.0);
}
#[tokio::test]
async fn test_validation_top_n_tokens() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(5),
..default_parameters()
},
})
.await
{
Err(ValidationError::TopNTokens(4, 5)) => (),
_ => panic!("Unexpected top_n_tokens"),
}
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
assert_eq!(valid_request.top_n_tokens, 0);
}
}

View File

@ -1,7 +1,9 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
@ -42,3 +44,22 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]
assert topn_tok_ids[2] == [0, 3, 1, 4]
assert topn_tok_ids[3] == [0, 3, 1, 4]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
assert topn_tok_logprobs[0] == []
assert topn_tok_logprobs[1] == [-1, -2]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
@ -12,6 +13,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@ -549,6 +576,12 @@ class CausalLM(Model):
generations: List[Generation] = []
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Zipped iterator
iterator = zip(
batch.requests,
@ -559,6 +592,9 @@ class CausalLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -571,6 +607,9 @@ class CausalLM(Model):
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -637,6 +676,24 @@ class CausalLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -645,6 +702,7 @@ class CausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1,5 +1,6 @@
import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -16,6 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token des not have a past
@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
stopping_criterias = []
top_n_tokens = []
blocks = 0
max_blocks = 0
@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = []
block_tables = []
@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_batch_size = 0
@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[
@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots)
@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -831,10 +857,14 @@ class FlashCausalLM(Model):
else:
next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser(
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -931,8 +961,11 @@ class FlashCausalLM(Model):
batch.all_input_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids,
next_token_logprobs,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -945,8 +978,11 @@ class FlashCausalLM(Model):
all_input_ids,
do_sample,
seed,
top_n_tokens,
next_token_id,
next_token_logprob,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
@ -1005,6 +1041,24 @@ class FlashCausalLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -1013,6 +1067,7 @@ class FlashCausalLM(Model):
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
else:
prefill_tokens = None
top_tokens=None
generation = Generation(
request.id,
prefill_tokens,
@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
@ -11,6 +12,7 @@ from text_generation_server.models.types import (
Batch,
Generation,
PrefillTokens,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Finished requests
generations: List[Generation] = []
stopped = True
@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from functools import total_ordering
import torch
from abc import ABC, abstractmethod
@ -71,6 +72,25 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass
class TopTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
)
def __len__(self):
return len(self.token_ids)
@dataclass
class Generation:
request_id: int
@ -80,6 +100,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -94,4 +116,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)

View File

@ -1,24 +1,20 @@
import re
from typing import Callable, List, Optional, Tuple
import torch
from transformers import (
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import (
static_warper,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousProcessorWrapper,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
next_logprobs = torch.gather(
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
).view(-1)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs
return next_ids, next_logprobs, logprobs
def filter(self, indices):
if self.watermark_processor is not None:
@ -339,3 +334,50 @@ class HeterogeneousSampling:
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
def batch_top_tokens(
top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
return (
[
idxs[:n] if req_n > 0 else []
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
],
[
vals[:n] if req_n > 0 else []
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
],
)