feat: Support stop sequences (#7)

This commit is contained in:
OlivierDehaene 2022-12-12 18:25:22 +01:00 committed by GitHub
parent 042180d88f
commit 718096f695
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 254 additions and 107 deletions

View File

@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
- Logits warpers (temperature scaling, topk ...)
- Stop sequences
## Officially supported models

View File

@ -28,12 +28,23 @@ message ClearCacheRequest {}
message ClearCacheResponse {}
message LogitsWarperParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// apply sampling on the logits
bool do_sample = 4;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
}
message Request {
/// Request ID
uint64 id = 1;
@ -43,8 +54,8 @@ message Request {
uint32 input_length = 3;
/// Logits Warper Parameters
LogitsWarperParameters parameters = 4;
/// Stopping criteria
uint32 max_new_tokens = 5;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
}
message Batch {
@ -63,6 +74,8 @@ message GeneratedText {
string output = 2;
/// Number of generated tokens
uint32 tokens = 3;
/// Finish reason
string finish_reason = 4;
}
message GenerateRequest {

View File

@ -6,7 +6,9 @@ mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use pb::generate::v1::{
Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
use tonic::transport;

View File

@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
let response = InferResponse {
output: output.output,
tokens: output.tokens,
finish_reason: output.finish_reason,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
end: Instant::now(),
@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
pub(crate) struct InferResponse {
pub(crate) output: String,
pub(crate) tokens: u32,
pub(crate) finish_reason: String,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,

View File

@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
use text_generation_client::{
Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
};
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
@ -72,7 +74,9 @@ impl State {
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
stopping_parameters: Some(StoppingCriteriaParameters::from(
entry.request.parameters.clone(),
)),
});
ids.push(*id);
@ -168,3 +172,12 @@ impl From<GenerateParameters> for LogitsWarperParameters {
}
}
}
impl From<GenerateParameters> for StoppingCriteriaParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
stop_sequences: parameters.stop,
max_new_tokens: parameters.max_new_tokens,
}
}
}

View File

@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters {
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
pub max_new_tokens: u32,
pub stop: Vec<String>,
}
fn default_temperature() -> f32 {
@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters {
top_p: default_top_p(),
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
stop: vec![],
}
}
@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest {
#[derive(Serialize)]
pub(crate) struct GeneratedText {
pub generated_text: String,
pub finish_reason: String,
}
#[derive(Serialize)]

View File

@ -53,6 +53,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
top_p: 1.0,
do_sample: false,
max_new_tokens: 1,
stop: vec![],
},
},
)
@ -88,11 +89,8 @@ async fn generate(
})?;
// Validate request
let (input_length, validated_request) = state
.validation
.validate(req.0)
.await
.map_err(|err| {
let (input_length, validated_request) =
state.validation.validate(req.0).await.map_err(|err| {
tracing::error!("{}", err.to_string());
err
})?;
@ -148,6 +146,7 @@ async fn generate(
// Send response
let response = vec![GeneratedText {
generated_text: response.output,
finish_reason: response.finish_reason,
}];
Ok((headers, Json(response)))
}

View File

@ -121,6 +121,14 @@ fn validation_worker(
.unwrap_or(());
continue;
}
if request.parameters.stop.len() > 4 {
response_tx
.send(Err(ValidationError::StopSequence(
request.parameters.stop.len(),
)))
.unwrap_or(());
continue;
}
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), false) {
@ -163,6 +171,8 @@ pub enum ValidationError {
MaxNewTokens,
#[error("inputs must have less than {1} tokens. Given: {0}")]
InputLength(usize, usize),
#[error("stop supports up to 4 stop sequences. Given: {0}")]
StopSequence(usize),
#[error("tokenizer error {0}")]
Tokenizer(String),
}

View File

@ -15,6 +15,11 @@ def default_pb_parameters():
)
@pytest.fixture
def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")

View File

@ -9,13 +9,13 @@ from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture
def default_pb_request(default_pb_parameters):
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
stopping_parameters=default_pb_stop_parameters,
)
@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
@ -85,6 +84,7 @@ def test_causal_lm_batch_type(default_bloom):
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
assert generated_texts == []
@ -92,7 +92,11 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert (
len(next_batch.all_input_ids[0])
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
assert all(
[p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
)
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):

View File

@ -8,13 +8,13 @@ from text_generation.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
stopping_parameters=default_pb_stop_parameters,
)
@ -33,7 +33,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@ -51,7 +51,6 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
@ -80,6 +79,7 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generated_texts, next_batch = default_causal_lm.generate_token(
default_causal_lm_batch
)
@ -88,8 +88,12 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert next_batch.all_input_ids[0][-1] == 6208
assert (
len(next_batch.all_input_ids[0])
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
@ -97,14 +101,18 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 6208
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
assert all(
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
def test_causal_lm_generate_token_completion(
@ -119,10 +127,7 @@ def test_causal_lm_generate_token_completion(
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert (
generated_texts[0].tokens
@ -145,7 +150,7 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert generated_texts[0].output == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
@ -166,10 +171,7 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)
@ -200,7 +202,8 @@ def test_batch_concatenate(
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 6208)
assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
@ -239,7 +242,7 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert generated_texts[0].output == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
@ -260,10 +263,7 @@ def test_batch_concatenate(
assert next_batch is not None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert (
generated_texts[0].tokens
@ -283,10 +283,7 @@ def test_batch_concatenate(
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)

View File

@ -8,13 +8,13 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=2,
parameters=default_pb_parameters,
max_new_tokens=10,
stopping_parameters=default_pb_stop_parameters,
)
@ -35,7 +35,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
@ -48,11 +48,12 @@ def default_seq2seq_lm():
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, 8)
assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
@ -86,6 +87,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generated_texts, next_batch = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
@ -108,7 +110,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert next_batch.decoder_input_ids[0, 0] == 0
assert next_batch.decoder_input_ids[0, 1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
@ -121,10 +123,16 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
[
p[2].shape == (next_batch.size, 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
[
p[3].shape == (next_batch.size, 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)

View File

@ -4,10 +4,55 @@ from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
)
def test_stop_sequence_criteria():
criteria = StopSequenceCriteria([1, 2, 3])
assert not criteria(1)
assert criteria.current_token_idx == 1
assert not criteria(2)
assert criteria.current_token_idx == 2
assert criteria(3)
assert criteria.current_token_idx == 3
def test_stop_sequence_criteria_reset():
criteria = StopSequenceCriteria([1, 2, 3])
assert not criteria(1)
assert criteria.current_token_idx == 1
assert not criteria(2)
assert criteria.current_token_idx == 2
assert not criteria(4)
assert criteria.current_token_idx == 0
def test_stop_sequence_criteria_empty():
with pytest.raises(ValueError):
StopSequenceCriteria([])
def test_stopping_criteria():
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
assert criteria([1]) == (False, None)
assert criteria([1, 2]) == (False, None)
assert criteria([1, 2, 3]) == (True, "stop_sequence")
def test_stopping_criteria_max():
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
assert criteria([1]) == (False, None)
assert criteria([1, 1]) == (False, None)
assert criteria([1, 1, 1]) == (False, None)
assert criteria([1, 1, 1, 1]) == (False, None)
assert criteria([1, 1, 1, 1, 1]) == (True, "length")
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]

View File

@ -57,23 +57,17 @@ class CausalLMBatch:
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
pad_to_multiple_of = 8 if "gpu" in str(device) else None
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
).to(device)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -123,8 +117,8 @@ class CausalLMBatch:
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids.shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
@ -331,14 +325,17 @@ class CausalLM(Model):
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
stop, reason = stopping_criteria(all_tokens)
if stop:
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(request, output, stopping_criteria.current_tokens)
GeneratedText(
request, output, stopping_criteria.current_tokens, reason
)
)
# add to the next batch
else:

View File

@ -94,18 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
tokenized_inputs = tokenizer(
@ -207,11 +198,7 @@ class GalacticaSharded(Galactica):
continue
module_name, param_name = name.rsplit(".", 1)
try:
module = model.get_submodule(module_name)
except Exception as e:
print(type(model), name, module_name, param_name)
raise e
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)

View File

@ -68,24 +68,18 @@ class Seq2SeqLMBatch:
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
# Tokenize batch
pad_to_multiple_of = 8 if "gpu" in str(device) else None
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
).to(device)
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@ -431,12 +425,15 @@ class Seq2SeqLM(Model):
decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
# Evaluate stopping criteria
if stopping_criteria(decoder_tokens):
stop, reason = stopping_criteria(decoder_tokens)
if stop:
# Decode tokens
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(request, output, stopping_criteria.current_tokens)
GeneratedText(
request, output, stopping_criteria.current_tokens, reason
)
)
# add to the next batch
else:

View File

@ -32,8 +32,12 @@ class GeneratedText:
request: generate_pb2.Request
output: str
tokens: int
reason: str
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(
request=self.request, output=self.output, tokens=self.tokens
request=self.request,
output=self.output,
tokens=self.tokens,
finish_reason=self.reason,
)

View File

@ -1,6 +1,5 @@
import concurrent
import os
import signal
import torch
import torch.distributed
@ -11,6 +10,8 @@ from functools import partial
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import AutoTokenizer
from transformers.generation.logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
@ -18,6 +19,8 @@ from transformers.generation.logits_process import (
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
class Sampling:
def __call__(self, logits):
@ -56,20 +59,72 @@ class NextTokenChooser:
next_ids = self.choice(scores)
return next_ids.unsqueeze(-1)
@classmethod
def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
)
class StopSequenceCriteria:
def __init__(self, tokens: List[int]):
if not tokens:
raise ValueError("tokens cannot be empty")
self.tokens = tokens
self.current_token_idx = 0
def __call__(self, last_token: int) -> bool:
if last_token == self.tokens[self.current_token_idx]:
# Increase idx to go to next token
self.current_token_idx += 1
else:
# Reset to first token of the stopping sequence
self.current_token_idx = 0
if self.current_token_idx == len(self.tokens):
# We matched the entire sequence without resetting
return True
return False
class StoppingCriteria:
def __init__(self, eos_token_id, max_new_tokens=20):
self.eos_token_id = eos_token_id
def __init__(
self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
):
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
def __call__(self, all_ids):
def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True
if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
return True
return False
return True, "length"
last_token = all_ids[-1]
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(last_token):
return True, "stop_sequence"
return False, None
@classmethod
def from_pb(
cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
) -> "StoppingCriteria":
stop_sequence_criterias = []
for stop_sequence in pb.stop_sequences:
tokens = tokenizer(
stop_sequence, padding=False, return_attention_mask=False
).input_ids
if tokens:
stop_sequence_criterias.append(StopSequenceCriteria(tokens))
stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
def initialize_torch_distributed():