# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
Nicolas Patry, 2023-08-28 11:43:47 +02:00, committed by GitHub
commit 211b54ac41 (parent 4486f78cf9)
23 changed files with 529 additions and 34 deletions


@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
@@ -64,6 +65,7 @@ async fn generate_runs(
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
+   top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    parameters: NextTokenChooserParameters,
@@ -82,6 +84,7 @@ async fn generate_runs(
                b,
                decode_length,
                parameters.clone(),
+               top_n_tokens,
                &mut client,
            )
            .await?;
@@ -97,6 +100,7 @@ async fn generate_runs(
                b,
                decode_length,
                parameters.clone(),
+               top_n_tokens,
                &mut client,
            )
            .await?;
@@ -130,6 +134,7 @@ async fn prefill(
    batch_size: u32,
    decode_length: u32,
    parameters: NextTokenChooserParameters,
+   top_n_tokens: Option<u32>,
    client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
    // Create requests
@@ -145,6 +150,7 @@ async fn prefill(
                stop_sequences: vec![],
                ignore_eos_token: true, // Will not stop even if a eos token is generated
            }),
+           top_n_tokens: top_n_tokens.unwrap_or(0),
        })
        .collect();


@@ -22,6 +22,7 @@ pub async fn run(
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
+   top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
        batch_size.clone(),
        sequence_length,
        decode_length,
+       top_n_tokens,
        n_runs,
        warmups,
        parameters,
@@ -130,6 +132,7 @@ pub async fn run(
        tokenizer_name,
        sequence_length,
        decode_length,
+       top_n_tokens,
        n_runs,
        warmups,
        temperature,


@@ -93,6 +93,11 @@ struct Args {
    /// decoding strategies, for full doc refer to the `text-generation-server`
    #[clap(long, env)]
    do_sample: bool,
+   /// Generation parameter in case you want to specifically test/debug particular
+   /// decoding strategies, for full doc refer to the `text-generation-server`
+   #[clap(long, env)]
+   top_n_tokens: Option<u32>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
        watermark,
        do_sample,
        master_shard_uds_path,
+       top_n_tokens,
    } = args;

    let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
        batch_size,
        sequence_length,
        decode_length,
+       top_n_tokens,
        runs,
        warmups,
        temperature,


@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
    tokenizer_name: String,
    sequence_length: u32,
    decode_length: u32,
+   top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
    builder.push_record(["Model", &tokenizer_name]);
    builder.push_record(["Sequence Length", &sequence_length.to_string()]);
    builder.push_record(["Decode Length", &decode_length.to_string()]);
+   builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
    builder.push_record(["N Runs", &n_runs.to_string()]);
    builder.push_record(["Warmups", &warmups.to_string()]);
    builder.push_record(["Temperature", &format!("{temperature:?}")]);


@@ -75,6 +75,7 @@ class Client:
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
+       top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
+           top_n_tokens (`int`):
+               Return the `n` most likely tokens at each step

        Returns:
            Response: generated response
@@ -134,6 +137,7 @@ class Client:
            typical_p=typical_p,
            watermark=watermark,
            decoder_input_details=decoder_input_details,
+           top_n_tokens=top_n_tokens
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -164,6 +168,7 @@ class Client:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
+       top_n_tokens: Optional[int] = None,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+           top_n_tokens (`int`):
+               Return the `n` most likely tokens at each step

        Returns:
            Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
+           top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +325,7 @@ class AsyncClient:
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
+       top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
+           top_n_tokens (`int`):
+               Return the `n` most likely tokens at each step

        Returns:
            Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
+           top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -404,6 +416,7 @@ class AsyncClient:
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
+       top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+           top_n_tokens (`int`):
+               Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
+           top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)
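For reference, a minimal sketch of how the new parameter is used from the Python client (the endpoint URL and prompt are placeholders for whatever server you are running):

```python
from text_generation import Client

# Placeholder endpoint; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "What is Deep Learning?",  # placeholder prompt
    max_new_tokens=5,
    top_n_tokens=3,  # ask for the 3 most likely tokens at each generation step
)

# `details.top_tokens` holds one list of Token objects (id, text, logprob, special)
# per generated token when `top_n_tokens` was requested.
if response.details.top_tokens:
    for step, candidates in enumerate(response.details.top_tokens):
        print(step, [(t.text, t.logprob) for t in candidates])
```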


@@ -39,6 +39,8 @@ class Parameters(BaseModel):
    details: bool = False
    # Get decoder input token logprobs and ids
    decoder_input_details: bool = False
+   # Return the N most likely tokens at each step
+   top_n_tokens: Optional[int]

    @validator("best_of")
    def valid_best_of(cls, field_value, values):
@@ -101,6 +103,12 @@ class Parameters(BaseModel):
            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
        return v

+   @validator("top_n_tokens")
+   def valid_top_n_tokens(cls, v):
+       if v is not None and v <= 0:
+           raise ValidationError("`top_n_tokens` must be strictly positive")
+       return v

class Request(BaseModel):
    # Prompt
@@ -125,9 +133,7 @@ class Request(BaseModel):
            and parameters.best_of > 1
            and field_value
        ):
-           raise ValidationError(
-               "`best_of` != 1 is not supported when `stream` == True"
-           )
+           raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
        return field_value
@@ -179,6 +185,8 @@ class BestOfSequence(BaseModel):
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
+   # Most likely tokens
+   top_tokens: Optional[List[List[Token]]]

# `generate` details
@@ -193,6 +201,8 @@ class Details(BaseModel):
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
+   # Most likely tokens
+   top_tokens: Optional[List[List[Token]]]
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]]
@@ -219,6 +229,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
    # Generated token
    token: Token
+   # Most likely tokens
+   top_tokens: Optional[List[Token]]
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str]
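The streaming side mirrors this: each `StreamResponse` now carries the candidates for the token it reports. A short sketch, with the same placeholder endpoint as above:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

for stream_response in client.generate_stream(
    "What is Deep Learning?",
    max_new_tokens=5,
    top_n_tokens=3,
):
    token = stream_response.token
    # `top_tokens` is the per-step candidate list; it may be empty or None when
    # top_n_tokens was not requested.
    candidates = stream_response.top_tokens or []
    print(token.text, [(t.text, t.logprob) for t in candidates])
```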


@@ -159,6 +159,14 @@ struct Args {
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

+   /// This is the maximum allowed value for clients to set `top_n_tokens`.
+   /// `top_n_tokens` is used to return information about the `n` most likely
+   /// tokens at each generation step, instead of just the sampled token. This
+   /// information can be used for downstream tasks like for classification or
+   /// ranking.
+   #[clap(default_value = "5", long, env)]
+   max_top_n_tokens: u32,
+
    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompt users can send which
    /// can impact the overall memory required to handle the load.
@@ -929,6 +937,8 @@ fn spawn_webserver(
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
+       "--max-top-n-tokens".to_string(),
+       args.max_top_n_tokens.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),


@@ -91,6 +91,8 @@ message Request {
    StoppingCriteriaParameters stopping_parameters = 5;
    /// Return prefill logprobs
    bool prefill_logprobs = 6;
+   /// Return most likely n tokens
+   uint32 top_n_tokens = 7;
}

message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
    repeated string texts = 3;
}

+message TopTokens {
+   /// Top Token IDs
+   repeated uint32 ids = 1;
+   /// Top Logprobs
+   repeated float logprobs = 2;
+   /// Top Token Texts
+   repeated string texts = 3;
+   /// If the tokens are special
+   repeated bool is_special = 6;
+}
+
message Generation {
    /// Request ID
    uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
    bool token_is_special = 6;
    /// Complete generated text
    optional GeneratedText generated_text = 7;
+   /// Top tokens
+   TopTokens top_tokens = 8;
}

message FilterBatchRequest {
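To make the shape of the new message concrete, here is a sketch that builds a `TopTokens` payload with the Python stubs generated from this proto (assuming they are importable as `text_generation_server.pb.generate_pb2`, as done elsewhere in this PR; all values are made up):

```python
from text_generation_server.pb import generate_pb2

# One TopTokens message describes the candidate tokens for a single generation step.
top_tokens = generate_pb2.TopTokens(
    ids=[287, 262, 290],           # hypothetical token ids
    logprobs=[-0.12, -2.4, -3.1],  # hypothetical log-probabilities
    texts=[" the", " a", " and"],  # decoded candidate texts
    is_special=[False, False, False],
)

# The new field 8 attaches those candidates to a Generation message.
generation = generate_pb2.Generation(request_id=0, top_tokens=top_tokens)
```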


@@ -131,6 +131,7 @@ impl Client {
                ignore_eos_token: false,
            }),
            prefill_logprobs: true,
+           top_n_tokens: 20,
        });
        n_tokens += max_input_length;
    }


@@ -50,6 +50,7 @@ impl Health {
            stop_sequences: vec![],
            ignore_eos_token: false,
        }),
+       top_n_tokens: 0,
    };
    let batch = Batch {
        id: BATCH_ID,


@@ -138,12 +138,15 @@ impl Infer {
        &self,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
+       let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

        // Create stream and keep semaphore permit as long as generate lives
        let (_permit, mut stream) = self.generate_stream(request).await?;

        // Return values
        let mut result_prefill = Vec::new();
        let mut result_tokens = Vec::new();
+       let mut result_top_tokens = Vec::new();
        let mut result_generated_text = None;
        let mut result_start = None;
        let mut result_queued = None;
@@ -164,7 +167,10 @@ impl Infer {
                    .collect();
                }
                // Push last token
-               InferStreamResponse::Token(token) => result_tokens.push(token),
+               InferStreamResponse::Intermediate { token, top_tokens } => {
+                   result_tokens.push(token);
+                   result_top_tokens.push(top_tokens);
+               }
                // Final message
                // Set return values
                InferStreamResponse::End {
@@ -172,8 +178,10 @@ impl Infer {
                    generated_text,
                    start,
                    queued,
+                   top_tokens,
                } => {
                    result_tokens.push(token);
+                   result_top_tokens.push(top_tokens);
                    result_generated_text = Some(generated_text);
                    result_start = Some(start);
                    result_queued = Some(queued)
@@ -191,6 +199,11 @@ impl Infer {
                generated_text,
                queued,
                start,
+               top_tokens: if use_top_tokens {
+                   result_top_tokens
+               } else {
+                   Vec::new()
+               },
            })
        } else {
            let err = InferError::IncompleteGeneration;
@@ -520,6 +533,26 @@ fn send_responses(
        special: generation.token_is_special,
    };

+   // generation.top_tokens
+   let mut top_tokens = Vec::new();
+   if let Some(top_tokens_) = generation.top_tokens {
+       top_tokens.extend(
+           top_tokens_
+               .ids
+               .into_iter()
+               .zip(top_tokens_.logprobs.into_iter())
+               .zip(top_tokens_.texts.into_iter())
+               .zip(top_tokens_.is_special.into_iter())
+               .map(|(((id, logprob), text), special)| Token {
+                   id,
+                   text,
+                   logprob,
+                   special,
+               }),
+       )
+   }
+
    if let Some(generated_text) = generation.generated_text {
        // Generation has ended
        stopped = true;
@@ -527,6 +560,7 @@ fn send_responses(
        entry.response_tx.send_timeout(
            Ok(InferStreamResponse::End {
                token,
+               top_tokens,
                generated_text,
                queued: entry.queue_time,
                start: entry.batch_time.unwrap(),
@@ -536,7 +570,7 @@ fn send_responses(
    } else {
        // Send message
        entry.response_tx.send_timeout(
-           Ok(InferStreamResponse::Token(token)),
+           Ok(InferStreamResponse::Intermediate { token, top_tokens }),
            Duration::from_millis(10),
        )?;
    }
@@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
    // Optional first message
    Prefill(PrefillTokens),
    // Intermediate messages
-   Token(Token),
+   Intermediate {
+       token: Token,
+       top_tokens: Vec<Token>,
+   },
    // Last message
    End {
        token: Token,
+       top_tokens: Vec<Token>,
        generated_text: GeneratedText,
        start: Instant,
        queued: Instant,
@@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
    pub(crate) generated_text: GeneratedText,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
+   pub(crate) top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Error)]


@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
        example = "null"
    )]
    pub seed: Option<u64>,
+   #[serde(default)]
+   #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+   pub top_n_tokens: Option<u32>,
}

fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
        details: false,
        decoder_input_details: false,
        seed: None,
+       top_n_tokens: None,
    }
}
@@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence {
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
+   #[serde(skip_serializing_if = "Vec::is_empty")]
+   pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Serialize, ToSchema)]
@@ -249,6 +255,8 @@ pub(crate) struct Details {
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
+   #[serde(skip_serializing_if = "Vec::is_empty")]
+   pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Serialize, ToSchema)]
@@ -272,6 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub token: Token,
+   #[serde(skip_serializing_if = "Vec::is_empty")]
+   pub top_tokens: Vec<Token>,
    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,
    #[schema(nullable = true, default = "null")]
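At the HTTP layer this shows up as an extra field under `parameters`. A hedged sketch of a raw request against a locally running server (URL and prompt are placeholders):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",  # placeholder URL
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "max_new_tokens": 5,
            "details": True,    # include per-token details in the response
            "top_n_tokens": 3,  # new parameter added in this PR
        },
    },
    timeout=60,
)
resp.raise_for_status()
payload = resp.json()

# With `details` enabled, `details.top_tokens` holds one candidate list per step;
# it is omitted from the JSON entirely when empty (skip_serializing_if above).
for candidates in payload.get("details", {}).get("top_tokens", []):
    print([(t["text"], t["logprob"]) for t in candidates])
```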


@@ -29,6 +29,8 @@ struct Args {
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
+   #[clap(default_value = "5", long, env)]
+   max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
    #[clap(default_value = "2048", long, env)]
@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
+       max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
@@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
+       max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,


@@ -235,6 +235,7 @@ impl State {
                truncate: entry.request.truncate,
                parameters: Some(entry.request.parameters.clone()),
                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+               top_n_tokens: entry.request.top_n_tokens,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -328,6 +329,7 @@ mod tests {
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
+               top_n_tokens: 0,
            },
            response_tx,
            span: info_span!("entry"),


@@ -158,7 +158,7 @@ async fn generate(
        add_prompt = Some(req.inputs.clone());
    }

-   let details = req.parameters.details || req.parameters.decoder_input_details;
+   let details: bool = req.parameters.details || req.parameters.decoder_input_details;

    // Inference
    let (response, best_of_responses) = match req.parameters.best_of {
@@ -191,6 +191,7 @@ async fn generate(
                        generated_tokens: response.generated_text.generated_tokens,
                        prefill: response.prefill,
                        tokens: response.tokens,
+                       top_tokens: response.top_tokens,
                        seed: response.generated_text.seed,
                    }
                })
@@ -204,6 +205,7 @@ async fn generate(
                tokens: response.tokens,
                seed: response.generated_text.seed,
                best_of_sequences,
+               top_tokens: response.top_tokens,
            })
        }
        false => None,
@@ -385,12 +387,16 @@ async fn generate_stream(
                        // Prefill is ignored
                        InferStreamResponse::Prefill(_) => {}
                        // Yield event for every new token
-                       InferStreamResponse::Token(token) => {
+                       InferStreamResponse::Intermediate{
+                           token,
+                           top_tokens,
+                       } => {
                            tracing::debug!(parent: &span, "Token: {:?}", token);

                            // StreamResponse
                            let stream_token = StreamResponse {
                                token,
+                               top_tokens: top_tokens,
                                generated_text: None,
                                details: None,
                            };
@@ -403,6 +409,7 @@ async fn generate_stream(
                            generated_text,
                            start,
                            queued,
+                           top_tokens,
                        } => {
                            // Token details
                            let details = match details {
@@ -451,6 +458,7 @@ async fn generate_stream(
                            let stream_token = StreamResponse {
                                token,
+                               top_tokens: top_tokens,
                                generated_text: Some(output_text),
                                details
                            };
@@ -509,6 +517,7 @@ pub async fn run(
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
+   max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
@@ -571,6 +580,7 @@ pub async fn run(
        tokenizer,
        max_best_of,
        max_stop_sequences,
+       max_top_n_tokens,
        max_input_length,
        max_total_tokens,
    );


@@ -15,6 +15,7 @@ pub struct Validation {
    /// Validation parameters
    max_best_of: usize,
    max_stop_sequences: usize,
+   max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
@@ -27,6 +28,7 @@ impl Validation {
        tokenizer: Option<Tokenizer>,
        max_best_of: usize,
        max_stop_sequences: usize,
+       max_top_n_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Self {
@@ -54,6 +56,7 @@ impl Validation {
            max_best_of,
            sender,
            max_stop_sequences,
+           max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        }
@@ -142,6 +145,7 @@ impl Validation {
            seed,
            watermark,
            decoder_input_details,
+           top_n_tokens,
            ..
        } = request.parameters;
@@ -218,6 +222,15 @@ impl Validation {
            }
        };

+       let top_n_tokens = top_n_tokens
+           .map(|value| {
+               if value > self.max_top_n_tokens {
+                   return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
+               }
+               Ok(value)
+           })
+           .unwrap_or(Ok(0))?;
+
        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
@@ -263,6 +276,7 @@ impl Validation {
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
+           top_n_tokens: top_n_tokens,
        })
    }
@@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest {
    pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
+   pub top_n_tokens: u32,
}

#[derive(Error, Debug)]
@@ -350,6 +365,10 @@ pub enum ValidationError {
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
+   #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
+   TopNTokens(u32, u32),
+   #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
+   TopNTokensDisabled,
    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
    PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
@@ -391,14 +410,16 @@ mod tests {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
-       let max_input_length = 4;
-       let max_total_tokens = 5;
+       let max_top_n_tokens = 4;
+       let max_input_length = 5;
+       let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
+           max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
@@ -418,14 +439,16 @@ mod tests {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
-       let max_input_length = 4;
-       let max_total_tokens = 5;
+       let max_top_n_tokens = 4;
+       let max_input_length = 5;
+       let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
+           max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
@@ -435,7 +458,7 @@ mod tests {
            .validate_input("Hello".to_string(), None, max_new_tokens)
            .await
        {
-           Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+           Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }
@@ -445,14 +468,16 @@ mod tests {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
-       let max_input_length = 4;
-       let max_total_tokens = 5;
+       let max_top_n_tokens = 4;
+       let max_input_length = 5;
+       let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
+           max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
@@ -477,14 +502,16 @@ mod tests {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
-       let max_input_length = 4;
-       let max_total_tokens = 5;
+       let max_top_n_tokens = 4;
+       let max_input_length = 5;
+       let max_total_tokens = 6;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
+           max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );
@@ -531,4 +558,75 @@ mod tests {
        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }
+
+   #[tokio::test]
+   async fn test_validation_top_n_tokens() {
+       let tokenizer = Some(get_tokenizer().await);
+       let max_best_of = 2;
+       let max_stop_sequences = 3;
+       let max_top_n_tokens = 4;
+       let max_input_length = 5;
+       let max_total_tokens = 6;
+       let workers = 1;
+       let validation = Validation::new(
+           workers,
+           tokenizer,
+           max_best_of,
+           max_stop_sequences,
+           max_top_n_tokens,
+           max_input_length,
+           max_total_tokens,
+       );
+       match validation
+           .validate(GenerateRequest {
+               inputs: "Hello".to_string(),
+               parameters: GenerateParameters {
+                   top_n_tokens: Some(5),
+                   ..default_parameters()
+               },
+           })
+           .await
+       {
+           Err(ValidationError::TopNTokens(4, 5)) => (),
+           _ => panic!("Unexpected top_n_tokens"),
+       }
+
+       validation
+           .validate(GenerateRequest {
+               inputs: "Hello".to_string(),
+               parameters: GenerateParameters {
+                   top_n_tokens: Some(4),
+                   max_new_tokens: 1,
+                   ..default_parameters()
+               },
+           })
+           .await
+           .unwrap();
+
+       validation
+           .validate(GenerateRequest {
+               inputs: "Hello".to_string(),
+               parameters: GenerateParameters {
+                   top_n_tokens: Some(0),
+                   max_new_tokens: 1,
+                   ..default_parameters()
+               },
+           })
+           .await
+           .unwrap();
+
+       let valid_request = validation
+           .validate(GenerateRequest {
+               inputs: "Hello".to_string(),
+               parameters: GenerateParameters {
+                   top_n_tokens: None,
+                   max_new_tokens: 1,
+                   ..default_parameters()
+               },
+           })
+           .await
+           .unwrap();
+
+       assert_eq!(valid_request.top_n_tokens, 0);
+   }
}
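From a client's point of view, exceeding the router limit is rejected before inference. A sketch of what that looks like through the Python client (how the error surfaces as a specific exception class is an assumption, so a broad catch is used here):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

try:
    # With the default --max-top-n-tokens 5, asking for more should be refused
    # by the router's validation step rather than reaching the model.
    client.generate("Hello", max_new_tokens=1, top_n_tokens=50)
except Exception as err:
    # Expected message format: "`top_n_tokens` must be >= 0 and <= 5. Given: 50"
    print(err)
```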


@@ -1,7 +1,9 @@
+import torch
+
from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
+   batch_top_tokens,
)
@@ -42,3 +44,22 @@ def test_stopping_criteria_max():
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+
+def test_batch_top_tokens():
+    top_n_tokens = [0, 2, 3, 4, 5]
+    top_n_tokens_tensor = torch.tensor(top_n_tokens)
+    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+
+    assert topn_tok_ids[0] == []
+    assert topn_tok_ids[1] == [0, 3]
+    assert topn_tok_ids[2] == [0, 3, 1, 4]
+    assert topn_tok_ids[3] == [0, 3, 1, 4]
+    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+    assert topn_tok_logprobs[0] == []
+    assert topn_tok_logprobs[1] == [-1, -2]
+    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
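The behaviour pinned down by this test (n == 0 returns nothing, and ties at the cut-off are kept, which is why n=3 and n=4 both yield four candidates) can be illustrated with a small per-request sketch. This is only an illustration of the selection rule, not the batched implementation in `text_generation_server.utils.tokens`:

```python
import torch


def top_n_with_ties(logprobs: torch.Tensor, n: int):
    """Return (ids, logprobs) of the n best tokens, keeping ties at the cut-off."""
    if n == 0:
        return [], []
    sorted_logprobs, sorted_ids = logprobs.sort(descending=True, stable=True)
    cutoff = sorted_logprobs[min(n, logprobs.numel()) - 1]
    keep = sorted_logprobs >= cutoff
    return sorted_ids[keep].tolist(), sorted_logprobs[keep].tolist()


inp_logprobs = torch.tensor([-1.0, -3.0, -4.0, -2.0, -3.0])
assert top_n_with_ties(inp_logprobs, 0) == ([], [])
assert top_n_with_ties(inp_logprobs, 2) == ([0, 3], [-1.0, -2.0])
assert top_n_with_ties(inp_logprobs, 3) == ([0, 3, 1, 4], [-1.0, -2.0, -3.0, -3.0])
assert top_n_with_ties(inp_logprobs, 4) == ([0, 3, 1, 4], [-1.0, -2.0, -3.0, -3.0])
```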


@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
@@ -12,6 +13,7 @@ from text_generation_server.models.types import (
    PrefillTokens,
    Generation,
    GeneratedText,
+   TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
+   top_n_tokens: List[int]
+   top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
@@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
+       top_n_tokens = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}
@@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
+           top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
@@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+       top_n_tokens_tensor = torch.tensor(
+           top_n_tokens, device=device, dtype=torch.int64
+       )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
+           top_n_tokens=top_n_tokens,
+           top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
@@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
        next_token_choosers = []
        stopping_criterias = []
+       top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0
@@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
+           top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
@@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

+       top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
@@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
+       self.top_n_tokens = top_n_tokens
+       self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens
@@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
+       top_n_tokens = []
        max_tokens = 0

        # Batch tensors
@@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
        attention_mask = None
        position_ids = None
        past_key_values = []
+       top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
@@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
+           top_n_tokens.extend(batch.top_n_tokens)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
@@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
                    (total_batch_size, max_input_length + padding_right_offset),
                )

+           if top_n_tokens_tensor is None:
+               top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                   total_batch_size,
+               )
+           top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
@@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
+           top_n_tokens=top_n_tokens,
+           top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -549,6 +576,12 @@ class CausalLM(Model):
        generations: List[Generation] = []
        stopped = True

+       batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+           batch.top_n_tokens,
+           batch.top_n_tokens_tensor,
+           torch.softmax(logits[:, -1], -1),
+       )
+
        # Zipped iterator
        iterator = zip(
            batch.requests,
@@ -559,6 +592,9 @@ class CausalLM(Model):
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
+           batch.top_n_tokens,
+           batch_top_token_ids,
+           batch_top_token_logprobs,
        )

        # For each member of the batch
@@ -571,6 +607,9 @@ class CausalLM(Model):
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
+           top_n_tokens,
+           top_token_ids,
+           top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
@@ -637,6 +676,24 @@ class CausalLM(Model):
            else:
                prefill_tokens = None

+           if top_n_tokens > 0:
+               toptoken_texts = self.tokenizer.batch_decode(
+                   top_token_ids,
+                   clean_up_tokenization_spaces=False,
+                   skip_special_tokens=False,
+               )
+               special_toptokens = [
+                   token_id in self.all_special_ids for token_id in top_token_ids
+               ]
+               top_tokens = TopTokens(
+                   top_token_ids,
+                   top_token_logprobs,
+                   toptoken_texts,
+                   special_toptokens,
+               )
+           else:
+               top_tokens = None
+
            generation = Generation(
                request.id,
                prefill_tokens,
@@ -645,6 +702,7 @@ class CausalLM(Model):
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
+               top_tokens,
            )

            generations.append(generation)


@ -1,5 +1,6 @@
import math import math
import itertools import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch import torch
import torch.distributed import torch.distributed
@ -16,6 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens, PrefillTokens,
Generation, Generation,
GeneratedText, GeneratedText,
TopTokens,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers # Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch # Number of blocks in this batch
blocks: int blocks: int
@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
) )
max_new_tokens = stopping_criteria.max_new_tokens max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention # Paged attention
# Remove one as the first token des not have a past # Remove one as the first token des not have a past
@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor( prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device prefill_next_token_indices, dtype=torch.int64, device=device
) )
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = [] read_offsets = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
blocks = 0 blocks = 0
max_blocks = 0 max_blocks = 0
@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx] stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_tokens = ( remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
) )
@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = [] start_slots = []
block_tables = [] block_tables = []
@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[ all_input_ids_tensor[
@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots) cumulative_slots += len(batch.slots)
@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@@ -831,10 +857,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+        )
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -931,8 +961,11 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
         # For each member of the batch
@@ -945,8 +978,11 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -1005,6 +1041,24 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -1013,6 +1067,7 @@ class FlashCausalLM(Model):
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
             generations.append(generation)
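For reference, the per-request top-token handling added above can be reproduced outside the server. Below is a minimal sketch of the decoding step (the `gpt2` tokenizer name and the token ids are illustrative assumptions, not values from this PR):

```python
# Minimal sketch of the top-token decoding performed in the generation loop above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer, for illustration only

top_token_ids = [262, 50256]       # hypothetical output of batch_top_tokens for one request
top_token_logprobs = [-0.4, -9.1]  # made-up log-probabilities

# Decode each candidate token individually, mirroring the loop above.
toptoken_texts = tokenizer.batch_decode(
    top_token_ids,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
special_toptokens = [
    token_id in tokenizer.all_special_ids for token_id in top_token_ids
]
print(list(zip(top_token_ids, toptoken_texts, special_toptokens)))
# With GPT-2's vocabulary this should print roughly:
# [(262, ' the', False), (50256, '<|endoftext|>', True)]
```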
@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
             else:
                 prefill_tokens = None
+            top_tokens=None
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens
             )
             generations.append(generation)
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
     # Metadata used for padding
     max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_input_length = 0
         max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0
         # Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []
         # Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
                     ),
                 )
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
         # For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
             generations.append(generation)
@@ -1,3 +1,4 @@
+from functools import total_ordering
 import torch
 from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
         return len(self.token_ids)
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+    def __len__(self):
+        return len(self.token_ids)
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +100,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[TopTokens]
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
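A small usage sketch of the new `TopTokens` dataclass (the values are invented, and running it requires the server package with its generated protobufs):

```python
from text_generation_server.models.types import TopTokens

# Hypothetical result for one decoding step of a request with top_n_tokens=2.
top = TopTokens(
    token_ids=[464, 262],
    logprobs=[-0.31, -1.87],
    texts=[" The", " the"],
    is_special=[False, False],
)
assert len(top) == 2
# Serializes to the new generate_pb2.TopTokens message (ids, logprobs, texts, is_special).
pb = top.to_pb()
```

Because `Generation.top_tokens` is optional, models that do not support the feature yet (such as Idefics above) simply pass `None`.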
@@ -1,24 +1,20 @@
 import re
+from typing import Callable, List, Optional, Tuple
 import torch
-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-from typing import List, Tuple, Optional
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 class NextTokenChooser:
@@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -339,3 +334,50 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+def batch_top_tokens(
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    max_top_n = max(top_n_tokens)
+    # Early exit when top_n_tokens is not used
+    if max_top_n == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+    top_n_ishes = top_n_ishes.tolist()
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+    return (
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
+    )
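To illustrate the "fuzzy" tie handling described in the docstring, here is a small sketch with made-up scores over a four-token vocabulary (it assumes `batch_top_tokens` is importable from `text_generation_server.utils.tokens`):

```python
import torch

from text_generation_server.utils.tokens import batch_top_tokens

# Two requests: the first asks for the top-1 token, the second disables top tokens.
top_n_tokens = [1, 0]
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

# Log-probabilities over a toy 4-token vocabulary; tokens 0 and 1 tie in the first row.
logits = torch.tensor([[3.0, 3.0, 1.0, 0.0], [2.0, 1.0, 0.5, 0.1]])
logprobs = torch.log_softmax(logits, dim=-1)

ids, logps = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, logprobs)
print(ids)    # [[0, 1], []]  both tied tokens are kept although only 1 was requested
              # (the order of tied tokens may vary)
print(logps)  # [[-0.78..., -0.78...], []]
```

The flash path above reuses the full log-softmax matrix returned by `HeterogeneousNextTokenChooser`, while the seq2seq path passes `torch.softmax(logits[:, -1], -1)`; both work because the selection only compares the relative magnitude of the scores.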