feat: Return logprobs (#8)

parent 718096f695
commit 32a253063d
@@ -0,0 +1,30 @@
+name: Server Tests
+
+on:
+  pull_request:
+    paths:
+      - "server/**"
+      - "proto/**"
+
+jobs:
+  run_tests:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.9
+      - name: Loading cache.
+        uses: actions/cache@v2
+        id: model_cache
+        with:
+          path: ~/.cache/huggingface/
+          key: models
+      - name: Install server dependencies
+        run: |
+          make install-server
+      - name: Run tests
+        run: |
+          pip install pytest
+          pytest -sv server/tests
@@ -17,6 +17,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - 45ms per token generation for BLOOM with 8xA100 80GB
 - Logits warpers (temperature scaling, topk ...)
 - Stop sequences
+- Log probabilities

 ## Officially supported models

@@ -27,7 +27,7 @@ message ClearCacheRequest {}
 /// Empty response
 message ClearCacheResponse {}

-message LogitsWarperParameters {
+message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
     /// restricting to the k highest probability elements
@@ -52,8 +52,8 @@ message Request {
     string inputs = 2;
     /// The number of tokens inside inputs
     uint32 input_length = 3;
-    /// Logits Warper Parameters
-    LogitsWarperParameters parameters = 4;
+    /// Next Token Chooser Parameters
+    NextTokenChooserParameters parameters = 4;
     /// Stopping Criteria Parameters
     StoppingCriteriaParameters stopping_parameters = 5;
 }
@@ -71,11 +71,17 @@ message GeneratedText {
     /// Request
     Request request = 1;
     /// Output
-    string output = 2;
+    string output_text = 2;
     /// Number of generated tokens
-    uint32 tokens = 3;
+    uint32 generated_tokens = 3;
+    /// Tokens
+    repeated string tokens = 4;
+    /// Token IDs
+    repeated uint32 token_ids = 5;
+    /// Logprobs
+    repeated float logprobs = 6;
     /// Finish reason
-    string finish_reason = 4;
+    string finish_reason = 7;
 }

 message GenerateRequest {
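Note: the new tokens, token_ids and logprobs fields are parallel repeated fields of equal length, one entry per position of the returned sequence. A minimal Python sketch of how a consumer might pair them up, mirroring the zip the router performs later on; `generated` is a hypothetical decoded GeneratedText message, not something defined in this diff:

# Sketch only: `generated` is assumed to be a decoded GeneratedText message.
def iter_token_details(generated):
    # The three repeated fields are index-aligned, so zipping them yields
    # (token_id, token_text, logprob) triples.
    return list(zip(generated.token_ids, generated.tokens, generated.logprobs))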
@@ -7,7 +7,7 @@ mod sharded_client;

 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -187,9 +187,13 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");

         let response = InferResponse {
-            output: output.output,
+            output_text: output.output_text,
+            generated_tokens: output.generated_tokens,
+            token_ids: output.token_ids,
+            tokens: output.tokens,
+            logprobs: output.logprobs,
             finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
@@ -202,8 +206,11 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) output: String,
-    pub(crate) tokens: u32,
+    pub(crate) output_text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) token_ids: Vec<u32>,
+    pub(crate) tokens: Vec<String>,
+    pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
@@ -5,7 +5,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -71,7 +71,7 @@ impl State {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
                 input_length: entry.input_length as u32,
-                parameters: Some(LogitsWarperParameters::from(
+                parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
                 stopping_parameters: Some(StoppingCriteriaParameters::from(
@@ -162,7 +162,7 @@ impl Db {
     }
 }

-impl From<GenerateParameters> for LogitsWarperParameters {
+impl From<GenerateParameters> for NextTokenChooserParameters {
     fn from(parameters: GenerateParameters) -> Self {
         Self {
             temperature: parameters.temperature,
@@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
     #[serde(default)]
     pub stop: Vec<String>,
+    #[serde(default)]
+    pub details: bool,
 }

 fn default_temperature() -> f32 {
@@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters {
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
+        details: false,
     }
 }

@@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

+#[derive(Serialize)]
+pub(crate) struct Details {
+    pub finish_reason: String,
+    pub generated_tokens: u32,
+    pub tokens: Vec<(u32, String, f32)>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
-    pub finish_reason: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub details: Option<Details>,
 }

 #[derive(Serialize)]
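Note: with these structs, a client can opt into per-token details. A hedged Python sketch of such a call; the URL, port and route are assumptions for illustration and are not defined in this diff, while the field names come from GenerateRequest / GenerateParameters above:

import requests

# Hypothetical endpoint; only the JSON field names are taken from the structs above.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "Test", "parameters": {"max_new_tokens": 2, "details": True}},
)
# With details requested, each response item is expected to also carry a "details"
# object (finish_reason, generated_tokens and (id, text, logprob) token triples);
# when details is false, the field is skipped entirely thanks to skip_serializing_if.
print(resp.json())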
@@ -1,5 +1,5 @@
 use crate::{
-    Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
+    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -54,6 +54,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 do_sample: false,
                 max_new_tokens: 1,
                 stop: vec![],
+                details: false,
             },
         },
     )
@@ -89,6 +90,7 @@ async fn generate(
     })?;

     // Validate request
+    let details = req.0.parameters.details;
     let (input_length, validated_request) =
         state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());
@@ -105,12 +107,31 @@ async fn generate(
             err
         })?;

+    // Token details
+    let details = match details {
+        true => {
+            let tokens = response
+                .token_ids
+                .into_iter()
+                .zip(response.tokens.into_iter())
+                .zip(response.logprobs.into_iter())
+                .map(|((id, text), logprob)| (id, text, logprob))
+                .collect();
+            Some(Details {
+                finish_reason: response.finish_reason,
+                generated_tokens: response.generated_tokens,
+                tokens,
+            })
+        }
+        false => None,
+    };
+
     // Timings
     let total_time = start_time.elapsed();
     let validation_time = response.queued - start_time;
     let queue_time = response.start - response.queued;
     let inference_time = response.end - response.start;
-    let time_per_token = inference_time / response.tokens;
+    let time_per_token = inference_time / response.generated_tokens;

     // Headers
     let mut headers = HeaderMap::new();
@@ -141,12 +162,12 @@ async fn generate(
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::info!("Output: {}", response.output);
+    tracing::info!("Output: {}", response.output_text);

     // Send response
     let response = vec![GeneratedText {
-        generated_text: response.output,
-        finish_reason: response.finish_reason,
+        generated_text: response.output_text,
+        details,
     }];
     Ok((headers, Json(response)))
 }
@@ -7,7 +7,7 @@ from text_generation.pb import generate_pb2

 @pytest.fixture
 def default_pb_parameters():
-    return generate_pb2.LogitsWarperParameters(
+    return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
         top_k=0,
         top_p=1.0,
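Note: client code now builds the parameters message under its renamed type. A minimal sketch, with values copied from the fixture above; any field not set is assumed to keep its protobuf default:

from text_generation.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    top_k=0,
    top_p=1.0,
)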
@@ -128,10 +128,12 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

@@ -170,10 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -240,10 +244,10 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

@@ -259,10 +263,12 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -279,9 +285,11 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -127,10 +127,11 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -150,12 +151,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784)"
+    assert generated_texts[0].output_text == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

@@ -171,12 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -242,12 +243,12 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784)"
+    assert generated_texts[0].output_text == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

@@ -263,10 +264,10 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -283,11 +284,11 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -148,9 +148,9 @@ def test_seq2seq_lm_generate_token_completion(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7


 def test_seq2seq_lm_generate_token_completion_multi(
@@ -166,12 +166,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few "
+    assert generated_texts[0].output_text == "a few "
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[1]
     )
-    assert generated_texts[0].tokens == 5
+    assert generated_texts[0].generated_tokens == 5

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert generated_texts == []
@@ -180,12 +180,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[0]
     )
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7


 def test_batch_concatenate(
@@ -287,28 +287,28 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few "
+    assert generated_texts[0].output_text == "a few "
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[1]
     )
-    assert generated_texts[0].tokens == 5
+    assert generated_texts[0].generated_tokens == 5

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[0]
     )
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7
@@ -246,12 +246,8 @@ class BLOOMSharded(BLOOM):
         )

         # Logits are sharded, so we need to gather them
-        logits_shard = outputs.logits[:, -1, :].contiguous()
-
-        batch_size, vocab_shard_size = logits_shard.shape
-        vocab_size = self.world_size * vocab_shard_size
-        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
-        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
+        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+        logits = torch.cat(logits, dim=2)

         return logits, outputs.past_key_values
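Note: the change above stops gathering only the last position's logits and instead gathers the full tensor along the vocabulary dimension, because logprobs are now needed for every position. A small standalone sketch of that collective; the helper name and shapes are illustrative, not part of the diff:

import torch
import torch.distributed as dist

def gather_full_logits(logits_shard, world_size, process_group):
    # Each rank holds [batch, seq_len, vocab // world_size]; concatenating the
    # gathered shards along dim=2 rebuilds the full vocabulary on every rank.
    shards = [torch.empty_like(logits_shard) for _ in range(world_size)]
    dist.all_gather(shards, logits_shard, group=process_group)
    return torch.cat(shards, dim=2)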
@@ -22,6 +22,7 @@ class CausalLMBatch:

     # All tokens
     all_input_ids: List[torch.Tensor]
+    all_logprobs: List[Optional[torch.Tensor]]

     # Lengths of all generations present in the batch
     input_lengths: List[int]
@@ -52,6 +53,7 @@ class CausalLMBatch:
         next_token_choosers = []
         stopping_criterias = []
         input_lengths = []
+        all_logprobs = []

         # Parse batch
         for r in pb.requests:
@@ -61,6 +63,7 @@ class CausalLMBatch:
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
+            all_logprobs.append(None)

         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
@@ -78,6 +81,7 @@ class CausalLMBatch:
             attention_mask=tokenized_inputs["attention_mask"],
             past_key_values=None,
             all_input_ids=all_input_ids,
+            all_logprobs=all_logprobs,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
@@ -95,6 +99,7 @@ class CausalLMBatch:
         requests = []
         input_lengths = []
         all_input_ids = []
+        all_logprobs = []
         next_token_choosers = []
         stopping_criterias = []

@@ -110,6 +115,7 @@ class CausalLMBatch:
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             all_input_ids.extend(batch.all_input_ids)
+            all_logprobs.extend(batch.all_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

@@ -217,6 +223,7 @@ class CausalLMBatch:
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
+            all_logprobs=all_logprobs,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
@@ -291,6 +298,7 @@ class CausalLM(Model):
         next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
+        next_batch_all_logprobs = []

         # Metadata
         next_batch_size = 0
@@ -307,6 +315,7 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.all_logprobs,
         )

         # For each member of the batch
@@ -316,34 +325,59 @@ class CausalLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            all_tokens,
+            all_input_ids,
+            all_logprobs,
         ) in enumerate(iterator):
             # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            next_token = tokens[-1].view(1, 1)

             # Append next token to all tokens
-            all_tokens = torch.cat([all_tokens, next_token])
+            all_input_ids = torch.cat([all_input_ids, next_token])
+            new_input_length = input_length + 1
+
+            if all_logprobs is None:
+                # logprobs of all prompt tokens (except the first one) and the generated token
+                all_logprobs = logprobs.gather(1, all_input_ids[1:])
+            else:
+                # logprob of the generated token
+                next_token_logprob = logprobs[-1, next_token]
+                all_logprobs = torch.cat([all_logprobs, next_token_logprob])

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_tokens)
+            stop, reason = stopping_criteria(all_input_ids)
             if stop:
                 # Decode all tokens
-                output = self.tokenizer.decode(
-                    all_tokens.squeeze(-1), skip_special_tokens=True
+                output_text = self.tokenizer.decode(
+                    all_input_ids.squeeze(-1), skip_special_tokens=True
                 )
+                # Slice with input_length to remove padding
+                token_ids = all_input_ids[-new_input_length:]
+                tokens = self.tokenizer.batch_decode(token_ids)
+                # Add NaN for the first prompt token
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
-                        request, output, stopping_criteria.current_tokens, reason
+                        request=request,
+                        output_text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        tokens=tokens,
+                        token_ids=token_ids.squeeze(1).tolist(),
+                        logprobs=logprobs,
+                        reason=reason,
                     )
                 )
             # add to the next batch
             else:
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token)
-                next_batch_all_input_ids.append(all_tokens)
+                next_batch_all_input_ids.append(all_input_ids)
+                next_batch_all_logprobs.append(all_logprobs)
                 next_batch_size += 1
-                new_input_length = input_length + 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
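Note: the `logprobs.gather(1, all_input_ids[1:])` line relies on the usual one-position offset: the logits at position i score the token at position i + 1, so the first prompt token never gets a logprob (hence the NaN prepended at decode time). A standalone sketch of the same computation with assumed 2-D shapes:

import torch

def prompt_logprobs(logits, input_ids):
    # logits: [seq_len, vocab] float tensor, input_ids: [seq_len] int64 tensor.
    # Position 0 has no preceding prediction and therefore no logprob.
    logprobs = torch.log_softmax(logits, dim=-1)
    targets = input_ids[1:].unsqueeze(1)
    return logprobs[:-1].gather(1, targets).squeeze(1)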
@@ -397,6 +431,7 @@ class CausalLM(Model):
             attention_mask=next_batch_attention_mask,
             past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
+            all_logprobs=next_batch_all_logprobs,
             input_lengths=next_batch_input_lengths,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
@@ -321,12 +321,8 @@ class GalacticaSharded(Galactica):
         )

         # Logits are sharded, so we need to gather them
-        logits_shard = outputs.logits[:, -1, :].contiguous()
-
-        batch_size, vocab_shard_size = logits_shard.shape
-        vocab_size = self.world_size * vocab_shard_size
-        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
-        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
+        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+        logits = torch.cat(logits, dim=2)

         return logits, outputs.past_key_values
@@ -30,6 +30,7 @@ class Seq2SeqLMBatch:
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
+    decoder_logprobs: List[Optional[torch.Tensor]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -60,6 +61,7 @@ class Seq2SeqLMBatch:

         decoder_input_ids = []
         decoder_input_lengths = []
+        decoder_logprobs = []

         # Parse batch
         for r in pb.requests:
@@ -72,6 +74,7 @@ class Seq2SeqLMBatch:
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
+            decoder_logprobs.append(None)

         # Tokenize batch
         pad_to_multiple_of = 8 if "gpu" in str(device) else None
@@ -95,6 +98,7 @@ class Seq2SeqLMBatch:
             past_key_values=None,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
@@ -117,6 +121,7 @@ class Seq2SeqLMBatch:
         requests = []
         input_lengths = []
         decoder_input_lengths = []
+        decoder_logprobs = []
         next_token_choosers = []
         stopping_criterias = []

@@ -137,6 +142,7 @@ class Seq2SeqLMBatch:
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
+            decoder_logprobs.extend(batch.decoder_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

@@ -286,6 +292,7 @@ class Seq2SeqLMBatch:
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -385,6 +392,7 @@ class Seq2SeqLM(Model):
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
+        next_batch_decoder_logprobs = []

         # Metadata
         next_batch_size = 0
@@ -399,6 +407,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.decoder_input_lengths,
+            batch.decoder_logprobs,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -411,38 +420,58 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             decoder_input_length,
+            decoder_logprobs,
             logits,
             next_token_chooser,
             stopping_criteria,
-            input_tokens,
-            decoder_tokens,
+            decoder_input_ids,
         ) in enumerate(iterator):
-            all_tokens = torch.cat([input_tokens, decoder_tokens])
             # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+            next_token, logprobs = next_token_chooser(decoder_input_ids, logits)

             # Append next token to decoder tokens
-            decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+            new_decoder_input_length = decoder_input_length + 1
+
+            next_token_logprob = logprobs[-1, next_token]
+            if decoder_logprobs is None:
+                decoder_logprobs = next_token_logprob
+            else:
+                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_tokens)
+            stop, reason = stopping_criteria(decoder_input_ids)
             if stop:
-                # Decode tokens
-                output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
+                # Slice with decoder_input_length to remove padding
+                # Decode all tokens
+                token_ids = decoder_input_ids[-new_decoder_input_length:]
+                output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+                tokens = self.tokenizer.batch_decode(token_ids)
+                # Add NaN for the bos token
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
-                        request, output, stopping_criteria.current_tokens, reason
+                        request=request,
+                        output_text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        tokens=tokens,
+                        token_ids=token_ids.tolist(),
+                        logprobs=logprobs,
+                        reason=reason,
                     )
                 )
             # add to the next batch
             else:
                 next_batch_keep_indices.append(i)
-                next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0))
+                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                 next_batch_size += 1
-                new_decoder_input_length = decoder_input_length + 1
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_decoder_logprobs.append(decoder_logprobs)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )
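Note: for the decoder, the first position is the bos/start token, which no previous position predicted, so its slot is filled with NaN; clients that aggregate per-token logprobs should skip it. A trivial sketch with invented values:

import math

# The first entry stands in for the bos slot the server fills with NaN.
logprobs = [float("nan"), -0.12, -1.57, -0.03]
total = sum(lp for lp in logprobs if not math.isnan(lp))
mean_logprob = total / (len(logprobs) - 1)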
@@ -515,6 +544,7 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
+            decoder_logprobs=next_batch_decoder_logprobs,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
@@ -30,14 +30,20 @@ class Batch(ABC):
 @dataclass
 class GeneratedText:
     request: generate_pb2.Request
-    output: str
-    tokens: int
+    output_text: str
+    generated_tokens: int
+    tokens: List[str]
+    token_ids: List[int]
+    logprobs: List[float]
     reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
             request=self.request,
-            output=self.output,
+            output_text=self.output_text,
+            generated_tokens=self.generated_tokens,
             tokens=self.tokens,
+            token_ids=self.token_ids,
+            logprobs=self.logprobs,
             finish_reason=self.reason,
         )
@@ -55,12 +55,16 @@ class NextTokenChooser:
         self.choice = Sampling() if sampling else Greedy()

     def __call__(self, input_ids, scores):
+        # Warp logits
         scores = self.warpers(input_ids, scores)
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, -1)
+        # Choose tokens
         next_ids = self.choice(scores)
-        return next_ids.unsqueeze(-1)
+        return next_ids, logprobs

     @classmethod
-    def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
+    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
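Note: the chooser now returns both the chosen ids and the full log-softmax matrix instead of a single unsqueezed id. A minimal greedy-only sketch of the new contract, with an assumed [seq_len, vocab] shape and argmax standing in for what Greedy() effectively does:

import torch

scores = torch.randn(4, 10)              # pretend warped logits
logprobs = torch.log_softmax(scores, -1)
next_ids = torch.argmax(scores, dim=-1)  # greedy choice
print(next_ids[-1].item(), logprobs[-1, next_ids[-1]].item())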