# What does this PR do?

Adds a `top_n_tokens` option end to end: clients can ask for the `n` most likely tokens at each generation step, in addition to the sampled token, for downstream tasks such as classification or ranking.

- Launcher/router: new `--max-top-n-tokens` limit (default 5) and request validation (`top_n_tokens` must be `>= 0` and `<= max_top_n_tokens`).
- Proto: new `TopTokens` message plus `top_n_tokens` on `Request` and `top_tokens` on `Generation`.
- Python client: `top_n_tokens` argument on `generate`/`generate_stream`; `top_tokens` fields on `Details`, `BestOfSequence`, and `StreamResponse`.
- Server: new `batch_top_tokens` helper and plumbing through the `CausalLM`, `FlashCausalLM`, and `Seq2SeqLM` batches (`IdeficsCausalLM` does not return top tokens yet).
- Benchmark: optional `--top-n-tokens` flag.

Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
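A minimal sketch of using the new option through the Python client once this change ships; the endpoint URL and prompt below are placeholders:

```python
from text_generation import Client

# Assumes a text-generation-inference server is already running;
# the URL and prompt are placeholders.
client = Client("http://127.0.0.1:8080")

# Ask for the 3 most likely tokens at every generation step,
# in addition to the token that was actually sampled.
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=5,
    top_n_tokens=3,
)

print(response.generated_text)
# `top_tokens` holds one list of candidate tokens per generated step.
if response.details.top_tokens:
    for step, candidates in enumerate(response.details.top_tokens):
        for token in candidates:
            print(step, token.id, repr(token.text), token.logprob)
```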
This commit is contained in:
parent 4486f78cf9
commit 211b54ac41
@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
@@ -64,6 +65,7 @@ async fn generate_runs(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -82,6 +84,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -97,6 +100,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -130,6 +134,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
+    top_n_tokens: Option<u32>,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -145,6 +150,7 @@ async fn prefill(
             stop_sequences: vec![],
             ignore_eos_token: true, // Will not stop even if an eos token is generated
         }),
+        top_n_tokens: top_n_tokens.unwrap_or(0),
     })
     .collect();

@@ -22,6 +22,7 @@ pub async fn run(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
         batch_size.clone(),
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         parameters,
@@ -130,6 +132,7 @@ pub async fn run(
         tokenizer_name,
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         temperature,

@@ -93,6 +93,11 @@ struct Args {
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
     do_sample: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_n_tokens: Option<u32>,
 }

 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         watermark,
         do_sample,
         master_shard_uds_path,
+        top_n_tokens,
     } = args;

     let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         batch_size,
         sequence_length,
         decode_length,
+        top_n_tokens,
         runs,
         warmups,
         temperature,

@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
     tokenizer_name: String,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Model", &tokenizer_name]);
     builder.push_record(["Sequence Length", &sequence_length.to_string()]);
     builder.push_record(["Decode Length", &decode_length.to_string()]);
+    builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
     builder.push_record(["N Runs", &n_runs.to_string()]);
     builder.push_record(["Warmups", &warmups.to_string()]);
     builder.push_record(["Temperature", &format!("{temperature:?}")]);

@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            top_n_tokens=top_n_tokens
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
             See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
             See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Return the N most likely tokens at each step
+    top_n_tokens: Optional[int]

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -101,6 +103,12 @@ class Parameters(BaseModel):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v

+    @validator("top_n_tokens")
+    def valid_top_n_tokens(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`top_n_tokens` must be strictly positive")
+        return v
+

 class Request(BaseModel):
     # Prompt
@@ -125,9 +133,7 @@ class Request(BaseModel):
             and parameters.best_of > 1
             and field_value
         ):
-            raise ValidationError(
-                "`best_of` != 1 is not supported when `stream` == True"
-            )
+            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
         return field_value

@@ -179,6 +185,8 @@ class BestOfSequence(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    top_tokens: Optional[List[List[Token]]]


 # `generate` details
@@ -193,6 +201,8 @@ class Details(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    top_tokens: Optional[List[List[Token]]]
     # Additional sequences when using the `best_of` parameter
     best_of_sequences: Optional[List[BestOfSequence]]

@@ -219,6 +229,8 @@ class StreamDetails(BaseModel):
 class StreamResponse(BaseModel):
     # Generated token
     token: Token
+    # Most likely tokens
+    top_tokens: Optional[List[Token]]
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str]

@@ -159,6 +159,14 @@ struct Args {
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,

+    /// This is the maximum allowed value for clients to set `top_n_tokens`.
+    /// `top_n_tokens` is used to return information about the `n` most likely
+    /// tokens at each generation step, instead of just the sampled token. This
+    /// information can be used for downstream tasks like classification or
+    /// ranking.
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+
     /// This is the maximum allowed input length (expressed in number of tokens)
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
@@ -929,6 +937,8 @@ fn spawn_webserver(
         args.max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         args.max_stop_sequences.to_string(),
+        "--max-top-n-tokens".to_string(),
+        args.max_top_n_tokens.to_string(),
         "--max-input-length".to_string(),
         args.max_input_length.to_string(),
         "--max-total-tokens".to_string(),

@@ -91,6 +91,8 @@ message Request {
     StoppingCriteriaParameters stopping_parameters = 5;
     /// Return prefill logprobs
     bool prefill_logprobs = 6;
+    /// Return most likely n tokens
+    uint32 top_n_tokens = 7;
 }

 message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }

+message TopTokens {
+    /// Top Token IDs
+    repeated uint32 ids = 1;
+    /// Top Logprobs
+    repeated float logprobs = 2;
+    /// Top Token Texts
+    repeated string texts = 3;
+    /// If the tokens are special
+    repeated bool is_special = 6;
+}
+
 message Generation {
     /// Request ID
     uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
     bool token_is_special = 6;
     /// Complete generated text
     optional GeneratedText generated_text = 7;
+    /// Top tokens
+    TopTokens top_tokens = 8;
 }

 message FilterBatchRequest {

@@ -131,6 +131,7 @@ impl Client {
                     ignore_eos_token: false,
                 }),
                 prefill_logprobs: true,
+                top_n_tokens: 20,
             });
             n_tokens += max_input_length;
         }

@@ -50,6 +50,7 @@ impl Health {
                 stop_sequences: vec![],
                 ignore_eos_token: false,
             }),
+            top_n_tokens: 0,
         };
         let batch = Batch {
             id: BATCH_ID,

@@ -138,12 +138,15 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
         // Create stream and keep semaphore permit as long as generate lives
         let (_permit, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
+        let mut result_top_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -164,7 +167,10 @@ impl Infer {
                     .collect();
                 }
                 // Push last token
-                InferStreamResponse::Token(token) => result_tokens.push(token),
+                InferStreamResponse::Intermediate { token, top_tokens } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
@@ -172,8 +178,10 @@ impl Infer {
                     generated_text,
                     start,
                     queued,
+                    top_tokens,
                 } => {
                     result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -191,6 +199,11 @@ impl Infer {
                 generated_text,
                 queued,
                 start,
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
             })
         } else {
             let err = InferError::IncompleteGeneration;
@@ -520,6 +533,26 @@ fn send_responses(
         special: generation.token_is_special,
     };

+    // generation.top_tokens
+
+    let mut top_tokens = Vec::new();
+    if let Some(top_tokens_) = generation.top_tokens {
+        top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                }),
+        )
+    }
+
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
         stopped = true;
@@ -527,6 +560,7 @@ fn send_responses(
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
+                top_tokens,
                 generated_text,
                 queued: entry.queue_time,
                 start: entry.batch_time.unwrap(),
@@ -536,7 +570,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Token(token)),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
@@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
     // Optional first message
     Prefill(PrefillTokens),
     // Intermediate messages
-    Token(Token),
+    Intermediate {
+        token: Token,
+        top_tokens: Vec<Token>,
+    },
     // Last message
     End {
         token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
+    pub(crate) top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Debug, Error)]

@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+    pub top_n_tokens: Option<u32>,
 }

 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        top_n_tokens: None,
     }
 }

@@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Serialize, ToSchema)]
@@ -249,6 +255,8 @@ pub(crate) struct Details {
     pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Serialize, ToSchema)]
@@ -272,6 +280,8 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
     pub token: Token,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]

@@ -29,6 +29,8 @@ struct Args {
     max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
     #[clap(default_value = "1024", long, env)]
     max_input_length: usize,
     #[clap(default_value = "2048", long, env)]
@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
         waiting_served_ratio,
@@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
         waiting_served_ratio,

@@ -235,6 +235,7 @@ impl State {
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                top_n_tokens: entry.request.top_n_tokens,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -328,6 +329,7 @@ mod tests {
                     max_new_tokens: 1,
                     stop_sequences: vec![],
                 },
+                top_n_tokens: 0,
             },
             response_tx,
             span: info_span!("entry"),

@@ -158,7 +158,7 @@ async fn generate(
         add_prompt = Some(req.inputs.clone());
     }

-    let details = req.parameters.details || req.parameters.decoder_input_details;
+    let details: bool = req.parameters.details || req.parameters.decoder_input_details;

     // Inference
     let (response, best_of_responses) = match req.parameters.best_of {
@@ -191,6 +191,7 @@ async fn generate(
                         generated_tokens: response.generated_text.generated_tokens,
                         prefill: response.prefill,
                         tokens: response.tokens,
+                        top_tokens: response.top_tokens,
                         seed: response.generated_text.seed,
                     }
                 })
@@ -204,6 +205,7 @@ async fn generate(
                 tokens: response.tokens,
                 seed: response.generated_text.seed,
                 best_of_sequences,
+                top_tokens: response.top_tokens,
             })
         }
         false => None,
@@ -385,12 +387,16 @@ async fn generate_stream(
                         // Prefill is ignored
                         InferStreamResponse::Prefill(_) => {}
                         // Yield event for every new token
-                        InferStreamResponse::Token(token) => {
+                        InferStreamResponse::Intermediate {
+                            token,
+                            top_tokens,
+                        } => {
                            tracing::debug!(parent: &span, "Token: {:?}", token);

                             // StreamResponse
                             let stream_token = StreamResponse {
                                 token,
+                                top_tokens: top_tokens,
                                 generated_text: None,
                                 details: None,
                             };
@@ -403,6 +409,7 @@ async fn generate_stream(
                             generated_text,
                             start,
                             queued,
+                            top_tokens,
                         } => {
                             // Token details
                             let details = match details {
@@ -451,6 +458,7 @@ async fn generate_stream(

                             let stream_token = StreamResponse {
                                 token,
+                                top_tokens: top_tokens,
                                 generated_text: Some(output_text),
                                 details
                             };
@@ -509,6 +517,7 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_best_of: usize,
     max_stop_sequences: usize,
+    max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
     waiting_served_ratio: f32,
@@ -571,6 +580,7 @@ pub async fn run(
         tokenizer,
         max_best_of,
         max_stop_sequences,
+        max_top_n_tokens,
         max_input_length,
         max_total_tokens,
     );

@@ -15,6 +15,7 @@ pub struct Validation {
     /// Validation parameters
     max_best_of: usize,
     max_stop_sequences: usize,
+    max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
@@ -27,6 +28,7 @@ impl Validation {
         tokenizer: Option<Tokenizer>,
         max_best_of: usize,
         max_stop_sequences: usize,
+        max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {
@@ -54,6 +56,7 @@ impl Validation {
             max_best_of,
             sender,
             max_stop_sequences,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         }
@@ -142,6 +145,7 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            top_n_tokens,
             ..
         } = request.parameters;

@@ -218,6 +222,15 @@ impl Validation {
             }
         };

+        let top_n_tokens = top_n_tokens
+            .map(|value| {
+                if value > self.max_top_n_tokens {
+                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
+                }
+                Ok(value)
+            })
+            .unwrap_or(Ok(0))?;
+
         // Check if inputs is empty
         if request.inputs.is_empty() {
             return Err(EmptyInput);
@@ -263,6 +276,7 @@ impl Validation {
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,
+            top_n_tokens: top_n_tokens,
         })
     }

@@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest {
     pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
+    pub top_n_tokens: u32,
 }

 #[derive(Error, Debug)]
@@ -350,6 +365,10 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
     BestOfStream,
+    #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
+    TopNTokens(u32, u32),
+    #[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
+    TopNTokensDisabled,
     #[error("`decoder_input_details` == true is not supported when streaming tokens")]
     PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
@@ -391,14 +410,16 @@ mod tests {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );
@@ -418,14 +439,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
        let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );
@@ -435,7 +458,7 @@ mod tests {
             .validate_input("Hello".to_string(), None, max_new_tokens)
             .await
         {
-            Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
             _ => panic!("Unexpected not max new tokens"),
         }
     }
@@ -445,14 +468,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );
@@ -477,14 +502,16 @@ mod tests {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
-        let max_input_length = 4;
-        let max_total_tokens = 5;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
         let workers = 1;
         let validation = Validation::new(
             workers,
             tokenizer,
             max_best_of,
             max_stop_sequence,
+            max_top_n_tokens,
             max_input_length,
             max_total_tokens,
         );
@@ -531,4 +558,75 @@ mod tests {
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
     }
+
+    #[tokio::test]
+    async fn test_validation_top_n_tokens() {
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequences = 3;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
+        let workers = 1;
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequences,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(5),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Err(ValidationError::TopNTokens(4, 5)) => (),
+            _ => panic!("Unexpected top_n_tokens"),
+        }
+
+        validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(4),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: Some(0),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_n_tokens: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+
+        assert_eq!(valid_request.top_n_tokens, 0);
+    }
 }

@@ -1,7 +1,9 @@
+import torch
 from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
+    batch_top_tokens,
 )

@@ -42,3 +44,22 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+
+def test_batch_top_tokens():
+    top_n_tokens = [0, 2, 3, 4, 5]
+    top_n_tokens_tensor = torch.tensor(top_n_tokens)
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+
+    assert topn_tok_ids[0] == []
+    assert topn_tok_ids[1] == [0, 3]
+    assert topn_tok_ids[2] == [0, 3, 1, 4]
+    assert topn_tok_ids[3] == [0, 3, 1, 4]
+    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+    assert topn_tok_logprobs[0] == []
+    assert topn_tok_logprobs[1] == [-1, -2]
+    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]

@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect

@@ -12,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Metadata used for padding
     max_input_length: int
@@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         prefix_offsets = []
         read_offsets = []
         requests_idx_mapping = {}
@@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -163,6 +174,7 @@ class CausalLMBatch(Batch):

         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
@@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values

+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         self.requests = requests
@@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors
@@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None

         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
@@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
                     (total_batch_size, max_input_length + padding_right_offset),
                 )

+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
@@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -549,6 +576,12 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -559,6 +592,9 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -571,6 +607,9 @@ class CausalLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -637,6 +676,24 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None

+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -645,6 +702,7 @@ class CausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)

@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed

@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Number of blocks in this batch
     blocks: int
@@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):

         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []

         # Cumulative length
         cumulative_length = 0
@@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)

             # Paged attention
             # Remove one as the first token does not have a past
@@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         return cls(
             batch_id=pb.id,
@@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []

         stopping_criterias = []
+        top_n_tokens = []

         blocks = 0
         max_blocks = 0
@@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)

+            top_n_tokens.append(self.top_n_tokens[idx])
+
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

@@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )

         start_slots = []
         block_tables = []
@@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):

         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []

         # Cumulative length
         cumulative_batch_size = 0
@@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots

             all_input_ids_tensor[
@@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)

+            top_n_tokens.extend(batch.top_n_tokens)
+
             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -831,10 +857,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -931,8 +961,11 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -945,8 +978,11 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -1005,6 +1041,24 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None

+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -1013,6 +1067,7 @@ class FlashCausalLM(Model):
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)

@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
             else:
                 prefill_tokens = None

+            top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens
             )

             generations.append(generation)

@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch

 from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Metadata used for padding
     max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):

         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         max_input_length = 0
         max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []

         # Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
                 ),
             )

+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None

+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)

@@ -1,3 +1,4 @@
+from functools import total_ordering
 import torch

 from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
     def __len__(self):
         return len(self.token_ids)


+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+
+    def __len__(self):
+        return len(self.token_ids)
+
+
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +100,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[TopTokens]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )

@@ -1,24 +1,20 @@
 import re
+from typing import Callable, List, Optional, Tuple

 import torch

-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-from typing import List, Tuple, Optional
-
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
-    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor


 class NextTokenChooser:
@@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs

     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -339,3 +334,50 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def batch_top_tokens(
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    max_top_n = max(top_n_tokens)
+    # Early exit when top_n_tokens is not used
+    if max_top_n == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
+    )
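To make the semantics of the new `batch_top_tokens` helper concrete, this small sketch replays the fixture from the unit test added above, including the tie handling at the requested cutoff:

```python
import torch
from text_generation_server.utils.tokens import batch_top_tokens

# Same fixture as the new unit test: five requests with budgets 0..5
# over identical rows of logprobs.
top_n = [0, 2, 3, 4, 5]
logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)

ids, lps = batch_top_tokens(top_n, torch.tensor(top_n), logprobs)

assert ids[0] == []      # a budget of 0 opts the request out entirely
assert ids[1] == [0, 3]  # the two highest logprobs (-1.0 and -2.0)
# Request 2 asked for 3 tokens, but tokens 1 and 4 tie at -3.0, so the
# "fuzzy" top-n keeps both tied tokens and four come back.
assert ids[2] == [0, 3, 1, 4]
assert lps[2] == [-1, -2, -3, -3]
```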