feat: Return logprobs (#8)

OlivierDehaene 2022-12-15 17:03:56 +01:00 committed by GitHub
parent 718096f695
commit 32a253063d
18 changed files with 247 additions and 94 deletions

.github/workflows/server-tests.yaml
View File

@ -0,0 +1,30 @@
name: Server Tests
on:
pull_request:
paths:
- "server/**"
- "proto/**"
jobs:
run_tests:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.9
- name: Loading cache.
uses: actions/cache@v2
id: model_cache
with:
path: ~/.cache/huggingface/
key: models
- name: Install server dependencies
run: |
make install-server
- name: Run tests
run: |
pip install pytest
pytest -sv server/tests

View File

@ -17,6 +17,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- 45ms per token generation for BLOOM with 8xA100 80GB
- Logits warpers (temperature scaling, topk ...)
- Stop sequences
- Log probabilities
## Officially supported models

View File

@ -27,7 +27,7 @@ message ClearCacheRequest {}
/// Empty response
message ClearCacheResponse {}
message LogitsWarperParameters {
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
@ -52,8 +52,8 @@ message Request {
string inputs = 2;
/// The number of tokens inside inputs
uint32 input_length = 3;
/// Logits Warper Parameters
LogitsWarperParameters parameters = 4;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
}
@ -71,11 +71,17 @@ message GeneratedText {
/// Request
Request request = 1;
/// Output
string output = 2;
string output_text = 2;
/// Number of generated tokens
uint32 tokens = 3;
uint32 generated_tokens = 3;
/// Tokens
repeated string tokens = 4;
/// Token IDs
repeated uint32 token_ids = 5;
/// Logprobs
repeated float logprobs = 6;
/// Finish reason
string finish_reason = 4;
string finish_reason = 7;
}
message GenerateRequest {
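
For reference, a minimal sketch (not part of this commit) of reading the new repeated fields from a GeneratedText message on the Python side, assuming generate_pb2 has been regenerated from the proto above:

from text_generation.pb import generate_pb2

def token_details(generated: generate_pb2.GeneratedText):
    # `tokens`, `token_ids` and `logprobs` are parallel repeated fields,
    # so zipping them yields (token_id, token_text, logprob) triples.
    return list(zip(generated.token_ids, generated.tokens, generated.logprobs))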

View File

@ -7,7 +7,7 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -187,9 +187,13 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
let response = InferResponse {
output: output.output,
output_text: output.output_text,
generated_tokens: output.generated_tokens,
token_ids: output.token_ids,
tokens: output.tokens,
logprobs: output.logprobs,
finish_reason: output.finish_reason,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
@ -202,8 +206,11 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) output: String,
pub(crate) tokens: u32,
pub(crate) output_text: String,
pub(crate) generated_tokens: u32,
pub(crate) token_ids: Vec<u32>,
pub(crate) tokens: Vec<String>,
pub(crate) logprobs: Vec<f32>,
pub(crate) finish_reason: String,
pub(crate) queued: Instant,
pub(crate) start: Instant,

View File

@ -5,7 +5,7 @@ use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
@ -71,7 +71,7 @@ impl State {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),
stopping_parameters: Some(StoppingCriteriaParameters::from(
@ -162,7 +162,7 @@ impl Db {
}
}
impl From<GenerateParameters> for LogitsWarperParameters {
impl From<GenerateParameters> for NextTokenChooserParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,

View File

@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters {
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
pub max_new_tokens: u32,
#[serde(default)]
pub stop: Vec<String>,
#[serde(default)]
pub details: bool,
}
fn default_temperature() -> f32 {
@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters {
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
stop: vec![],
details: false,
}
}
@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[derive(Serialize)]
pub(crate) struct Details {
pub finish_reason: String,
pub generated_tokens: u32,
pub tokens: Vec<(u32, String, f32)>,
}
#[derive(Serialize)]
pub(crate) struct GeneratedText {
pub generated_text: String,
pub finish_reason: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>,
}
#[derive(Serialize)]
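
Illustratively (not part of the diff), the JSON shape these structs serialize to when details are requested; every value below is made up:

example_response = [
    {
        "generated_text": "a few weeks",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 7,
            # (token_id, token_text, logprob) tuples serialize as JSON arrays
            "tokens": [[3, "a", -2.3], [360, " few", -1.2]],
        },
    }
]
# When `details` is false the `details` key is omitted entirely, because of
# `skip_serializing_if = "Option::is_none"`.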

View File

@ -1,5 +1,5 @@
use crate::{
Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
@ -54,6 +54,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
do_sample: false,
max_new_tokens: 1,
stop: vec![],
details: false,
},
},
)
@ -89,6 +90,7 @@ async fn generate(
})?;
// Validate request
let details = req.0.parameters.details;
let (input_length, validated_request) =
state.validation.validate(req.0).await.map_err(|err| {
tracing::error!("{}", err.to_string());
@ -105,12 +107,31 @@ async fn generate(
err
})?;
// Token details
let details = match details {
true => {
let tokens = response
.token_ids
.into_iter()
.zip(response.tokens.into_iter())
.zip(response.logprobs.into_iter())
.map(|((id, text), logprob)| (id, text, logprob))
.collect();
Some(Details {
finish_reason: response.finish_reason,
generated_tokens: response.generated_tokens,
tokens,
})
}
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = response.queued - start_time;
let queue_time = response.start - response.queued;
let inference_time = response.end - response.start;
let time_per_token = inference_time / response.tokens;
let time_per_token = inference_time / response.generated_tokens;
// Headers
let mut headers = HeaderMap::new();
@ -141,12 +162,12 @@ async fn generate(
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
tracing::info!("Output: {}", response.output);
tracing::info!("Output: {}", response.output_text);
// Send response
let response = vec![GeneratedText {
generated_text: response.output,
finish_reason: response.finish_reason,
generated_text: response.output_text,
details,
}];
Ok((headers, Json(response)))
}
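
A minimal client sketch (not part of this commit), assuming the router is running locally on port 3000 and serving the /generate route handled above; the port, prompt and parameters are illustrative:

import requests

response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "Deep learning is",
        "parameters": {"max_new_tokens": 5, "details": True},
    },
)
response.raise_for_status()
result = response.json()[0]  # the router returns a list with one generation

print(result["generated_text"])
if result.get("details") is not None:
    # Each entry is a (token_id, token_text, logprob) triple; the first
    # prompt token carries no meaningful logprob (it is stored as NaN).
    for token_id, text, logprob in result["details"]["tokens"]:
        print(token_id, repr(text), logprob)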

View File

@ -7,7 +7,7 @@ from text_generation.pb import generate_pb2
@pytest.fixture
def default_pb_parameters():
return generate_pb2.LogitsWarperParameters(
return generate_pb2.NextTokenChooserParameters(
temperature=1.0,
top_k=0,
top_p=1.0,

View File

@ -128,10 +128,12 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
@ -170,10 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -240,10 +244,10 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
@ -259,10 +263,12 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -279,9 +285,11 @@ def test_batch_concatenate(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -127,10 +127,11 @@ def test_causal_lm_generate_token_completion(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -150,12 +151,12 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784)"
assert generated_texts[0].output_text == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
@ -171,12 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -242,12 +243,12 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784)"
assert generated_texts[0].output_text == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
@ -263,10 +264,10 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -283,11 +284,11 @@ def test_batch_concatenate(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)
assert (
generated_texts[0].tokens
generated_texts[0].generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -148,9 +148,9 @@ def test_seq2seq_lm_generate_token_completion(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].output_text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
assert generated_texts[0].generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
@ -166,12 +166,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert generated_texts[0].output_text == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
)
assert generated_texts[0].tokens == 5
assert generated_texts[0].generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
@ -180,12 +180,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].output_text == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
)
assert generated_texts[0].tokens == 7
assert generated_texts[0].generated_tokens == 7
def test_batch_concatenate(
@ -287,28 +287,28 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert generated_texts[0].output_text == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
)
assert generated_texts[0].tokens == 5
assert generated_texts[0].generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].output_text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
assert generated_texts[0].generated_tokens == 7
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].output_text == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
)
assert generated_texts[0].tokens == 7
assert generated_texts[0].generated_tokens == 7

View File

@ -246,12 +246,8 @@ class BLOOMSharded(BLOOM):
)
# Logits are sharded, so we need to gather them
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
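
The gather above now collects the full [batch, seq, vocab] logits instead of only the last position, so log-probabilities can be computed for every token. A standalone sketch (no distributed setup, shards simulated with torch.chunk) of why concatenating per-rank shards along the vocabulary dimension rebuilds the full tensor:

import torch

world_size = 4
batch, seq, vocab = 2, 5, 16
full_logits = torch.randn(batch, seq, vocab)

# Each rank holds one slice of the vocabulary dimension.
shards = list(torch.chunk(full_logits, world_size, dim=2))

# torch.distributed.all_gather fills a list like `shards` on every rank;
# concatenating on dim=2 reassembles the complete vocabulary dimension.
gathered = torch.cat(shards, dim=2)
assert torch.equal(gathered, full_logits)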

View File

@ -22,6 +22,7 @@ class CausalLMBatch:
# All tokens
all_input_ids: List[torch.Tensor]
all_logprobs: List[Optional[torch.Tensor]]
# Lengths of all generations present in the batch
input_lengths: List[int]
@ -52,6 +53,7 @@ class CausalLMBatch:
next_token_choosers = []
stopping_criterias = []
input_lengths = []
all_logprobs = []
# Parse batch
for r in pb.requests:
@ -61,6 +63,7 @@ class CausalLMBatch:
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
all_logprobs.append(None)
pad_to_multiple_of = 8 if "gpu" in str(device) else None
tokenized_inputs = tokenizer(
@ -78,6 +81,7 @@ class CausalLMBatch:
attention_mask=tokenized_inputs["attention_mask"],
past_key_values=None,
all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
@ -95,6 +99,7 @@ class CausalLMBatch:
requests = []
input_lengths = []
all_input_ids = []
all_logprobs = []
next_token_choosers = []
stopping_criterias = []
@ -110,6 +115,7 @@ class CausalLMBatch:
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids)
all_logprobs.extend(batch.all_logprobs)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@ -217,6 +223,7 @@ class CausalLMBatch:
attention_mask=attention_mask,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
@ -291,6 +298,7 @@ class CausalLM(Model):
next_batch_input_lengths = []
next_batch_input_ids = []
next_batch_all_input_ids = []
next_batch_all_logprobs = []
# Metadata
next_batch_size = 0
@ -307,6 +315,7 @@ class CausalLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_logprobs,
)
# For each member of the batch
@ -316,34 +325,59 @@ class CausalLM(Model):
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
all_input_ids,
all_logprobs,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
tokens, logprobs = next_token_chooser(all_input_ids, logits)
next_token = tokens[-1].view(1, 1)
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
all_input_ids = torch.cat([all_input_ids, next_token])
new_input_length = input_length + 1
if all_logprobs is None:
# logprobs of all prompt tokens (except the first one) and the generated token
all_logprobs = logprobs.gather(1, all_input_ids[1:])
else:
# logprob of the generated token
next_token_logprob = logprobs[-1, next_token]
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
# Evaluate stopping criteria
stop, reason = stopping_criteria(all_tokens)
stop, reason = stopping_criteria(all_input_ids)
if stop:
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
output_text = self.tokenizer.decode(
all_input_ids.squeeze(-1), skip_special_tokens=True
)
# Slice with input_length to remove padding
token_ids = all_input_ids[-new_input_length:]
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the first prompt token
logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
1
).tolist()
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
request, output, stopping_criteria.current_tokens, reason
request=request,
output_text=output_text,
generated_tokens=stopping_criteria.current_tokens,
tokens=tokens,
token_ids=token_ids.squeeze(1).tolist(),
logprobs=logprobs,
reason=reason,
)
)
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_all_input_ids.append(all_input_ids)
next_batch_all_logprobs.append(all_logprobs)
next_batch_size += 1
new_input_length = input_length + 1
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
@ -397,6 +431,7 @@ class CausalLM(Model):
attention_mask=next_batch_attention_mask,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
all_logprobs=next_batch_all_logprobs,
input_lengths=next_batch_input_lengths,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
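
A self-contained sketch (not part of the diff) of the log_softmax + gather pattern used above: logits at position i score the token at position i + 1, so the first prompt token gets no logprob, hence the NaN placeholder. Shapes and ids are illustrative:

import torch

vocab_size = 50
input_ids = torch.tensor([[3, 17, 42, 8]])               # [1, seq_len]
logits = torch.randn(1, input_ids.shape[1], vocab_size)  # [1, seq_len, vocab]

logprobs = torch.log_softmax(logits, dim=-1)

# logprobs[:, i] is the distribution over the token at position i + 1,
# so align positions 0..seq-2 with token ids 1..seq-1.
token_logprobs = logprobs[:, :-1].gather(
    2, input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)                                             # [1, seq_len - 1]

# Prepend NaN for the first token, mirroring the code above.
per_token = [float("nan")] + token_logprobs[0].tolist()
print(per_token)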

View File

@ -321,12 +321,8 @@ class GalacticaSharded(Galactica):
)
# Logits are sharded, so we need to gather them
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values

View File

@ -30,6 +30,7 @@ class Seq2SeqLMBatch:
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
decoder_logprobs: List[Optional[torch.Tensor]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@ -60,6 +61,7 @@ class Seq2SeqLMBatch:
decoder_input_ids = []
decoder_input_lengths = []
decoder_logprobs = []
# Parse batch
for r in pb.requests:
@ -72,6 +74,7 @@ class Seq2SeqLMBatch:
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
decoder_logprobs.append(None)
# Tokenize batch
pad_to_multiple_of = 8 if "gpu" in str(device) else None
@ -95,6 +98,7 @@ class Seq2SeqLMBatch:
past_key_values=None,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
@ -117,6 +121,7 @@ class Seq2SeqLMBatch:
requests = []
input_lengths = []
decoder_input_lengths = []
decoder_logprobs = []
next_token_choosers = []
stopping_criterias = []
@ -137,6 +142,7 @@ class Seq2SeqLMBatch:
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
decoder_logprobs.extend(batch.decoder_logprobs)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@ -286,6 +292,7 @@ class Seq2SeqLMBatch:
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@ -385,6 +392,7 @@ class Seq2SeqLM(Model):
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
next_batch_decoder_logprobs = []
# Metadata
next_batch_size = 0
@ -399,6 +407,7 @@ class Seq2SeqLM(Model):
batch.requests,
batch.input_lengths,
batch.decoder_input_lengths,
batch.decoder_logprobs,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@ -411,38 +420,58 @@ class Seq2SeqLM(Model):
request,
input_length,
decoder_input_length,
decoder_logprobs,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_tokens,
decoder_input_ids,
) in enumerate(iterator):
all_tokens = torch.cat([input_tokens, decoder_tokens])
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
# Append next token to decoder tokens
decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
decoder_input_ids = torch.cat([decoder_input_ids, next_token])
new_decoder_input_length = decoder_input_length + 1
next_token_logprob = logprobs[-1, next_token]
if decoder_logprobs is None:
decoder_logprobs = next_token_logprob
else:
decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
# Evaluate stopping criteria
stop, reason = stopping_criteria(decoder_tokens)
stop, reason = stopping_criteria(decoder_input_ids)
if stop:
# Decode tokens
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
# Slice with decoder_input_length to remove padding
# Decode all tokens
token_ids = decoder_input_ids[-new_decoder_input_length:]
output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[
-new_decoder_input_length:
].tolist()
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
request, output, stopping_criteria.current_tokens, reason
request=request,
output_text=output_text,
generated_tokens=stopping_criteria.current_tokens,
tokens=tokens,
token_ids=token_ids.tolist(),
logprobs=logprobs,
reason=reason,
)
)
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0))
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
new_decoder_input_length = decoder_input_length + 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_decoder_logprobs.append(decoder_logprobs)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
@ -515,6 +544,7 @@ class Seq2SeqLM(Model):
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
decoder_logprobs=next_batch_decoder_logprobs,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,

View File

@ -30,14 +30,20 @@ class Batch(ABC):
@dataclass
class GeneratedText:
request: generate_pb2.Request
output: str
tokens: int
output_text: str
generated_tokens: int
tokens: List[str]
token_ids: List[int]
logprobs: List[float]
reason: str
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(
request=self.request,
output=self.output,
output_text=self.output_text,
generated_tokens=self.generated_tokens,
tokens=self.tokens,
token_ids=self.token_ids,
logprobs=self.logprobs,
finish_reason=self.reason,
)

View File

@ -55,12 +55,16 @@ class NextTokenChooser:
self.choice = Sampling() if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_ids = self.choice(scores)
return next_ids.unsqueeze(-1)
return next_ids, logprobs
@classmethod
def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
top_k=pb.top_k,
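
A minimal sketch of the new chooser contract (greedy selection, no warpers, not the actual implementation): the call returns both the chosen ids and the full log-probabilities, and the caller takes the last position as the next token:

import torch

def choose_next(scores: torch.Tensor):
    """scores: [seq_len, vocab_size] logits (after any warping)."""
    logprobs = torch.log_softmax(scores, dim=-1)
    next_ids = torch.argmax(scores, dim=-1)
    return next_ids, logprobs

scores = torch.randn(4, 50)
next_ids, logprobs = choose_next(scores)

# As in the causal LM loop above: the appended token is the last position's
# choice, and its logprob comes from the last row of `logprobs`.
next_token = next_ids[-1].view(1, 1)
next_token_logprob = logprobs[-1, next_token]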