feat: Support stop sequences (#7)
This commit is contained in:
parent
042180d88f
commit
718096f695
|
@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
|
|||
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
||||
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
|
||||
- 45ms per token generation for BLOOM with 8xA100 80GB
|
||||
- Logits warpers (temperature scaling, topk ...)
|
||||
- Stop sequences
|
||||
|
||||
## Officially supported models
|
||||
|
||||
|
|
|
@ -28,12 +28,23 @@ message ClearCacheRequest {}
|
|||
message ClearCacheResponse {}
|
||||
|
||||
message LogitsWarperParameters {
|
||||
/// exponential scaling output probability distribution
|
||||
float temperature = 1;
|
||||
/// restricting to the k highest probability elements
|
||||
uint32 top_k = 2;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float top_p = 3;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 4;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
/// Maximum number of generated tokens
|
||||
uint32 max_new_tokens = 1;
|
||||
/// Optional stopping sequences
|
||||
repeated string stop_sequences = 2;
|
||||
}
|
||||
|
||||
message Request {
|
||||
/// Request ID
|
||||
uint64 id = 1;
|
||||
|
@ -43,8 +54,8 @@ message Request {
|
|||
uint32 input_length = 3;
|
||||
/// Logits Warper Parameters
|
||||
LogitsWarperParameters parameters = 4;
|
||||
/// Stopping criteria
|
||||
uint32 max_new_tokens = 5;
|
||||
/// Stopping Criteria Parameters
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
@ -63,6 +74,8 @@ message GeneratedText {
|
|||
string output = 2;
|
||||
/// Number of generated tokens
|
||||
uint32 tokens = 3;
|
||||
/// Finish reason
|
||||
string finish_reason = 4;
|
||||
}
|
||||
|
||||
message GenerateRequest {
|
||||
|
|
|
@ -6,7 +6,9 @@ mod pb;
|
|||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
|
||||
pub use pb::generate::v1::{
|
||||
Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
|
|
|
@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
|||
let response = InferResponse {
|
||||
output: output.output,
|
||||
tokens: output.tokens,
|
||||
finish_reason: output.finish_reason,
|
||||
queued: entry.time,
|
||||
start: entry.batch_time.unwrap(), // unwrap is always valid
|
||||
end: Instant::now(),
|
||||
|
@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
|||
pub(crate) struct InferResponse {
|
||||
pub(crate) output: String,
|
||||
pub(crate) tokens: u32,
|
||||
pub(crate) finish_reason: String,
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) end: Instant,
|
||||
|
|
|
@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest};
|
|||
use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
|
||||
|
@ -72,7 +74,9 @@ impl State {
|
|||
parameters: Some(LogitsWarperParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
max_new_tokens: entry.request.parameters.max_new_tokens,
|
||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
});
|
||||
|
||||
ids.push(*id);
|
||||
|
@ -168,3 +172,12 @@ impl From<GenerateParameters> for LogitsWarperParameters {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GenerateParameters> for StoppingCriteriaParameters {
|
||||
fn from(parameters: GenerateParameters) -> Self {
|
||||
Self {
|
||||
stop_sequences: parameters.stop,
|
||||
max_new_tokens: parameters.max_new_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters {
|
|||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
pub max_new_tokens: u32,
|
||||
pub stop: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
|
@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
top_p: default_top_p(),
|
||||
do_sample: default_do_sample(),
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
stop: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest {
|
|||
#[derive(Serialize)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub generated_text: String,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
|
@ -53,6 +53,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
|||
top_p: 1.0,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
stop: vec![],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -88,11 +89,8 @@ async fn generate(
|
|||
})?;
|
||||
|
||||
// Validate request
|
||||
let (input_length, validated_request) = state
|
||||
.validation
|
||||
.validate(req.0)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
let (input_length, validated_request) =
|
||||
state.validation.validate(req.0).await.map_err(|err| {
|
||||
tracing::error!("{}", err.to_string());
|
||||
err
|
||||
})?;
|
||||
|
@ -148,6 +146,7 @@ async fn generate(
|
|||
// Send response
|
||||
let response = vec![GeneratedText {
|
||||
generated_text: response.output,
|
||||
finish_reason: response.finish_reason,
|
||||
}];
|
||||
Ok((headers, Json(response)))
|
||||
}
|
||||
|
|
|
@ -121,6 +121,14 @@ fn validation_worker(
|
|||
.unwrap_or(());
|
||||
continue;
|
||||
}
|
||||
if request.parameters.stop.len() > 4 {
|
||||
response_tx
|
||||
.send(Err(ValidationError::StopSequence(
|
||||
request.parameters.stop.len(),
|
||||
)))
|
||||
.unwrap_or(());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the number of tokens in the input
|
||||
match tokenizer.encode(request.inputs.clone(), false) {
|
||||
|
@ -163,6 +171,8 @@ pub enum ValidationError {
|
|||
MaxNewTokens,
|
||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||
InputLength(usize, usize),
|
||||
#[error("stop supports up to 4 stop sequences. Given: {0}")]
|
||||
StopSequence(usize),
|
||||
#[error("tokenizer error {0}")]
|
||||
Tokenizer(String),
|
||||
}
|
||||
|
|
|
@ -15,6 +15,11 @@ def default_pb_parameters():
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_stop_parameters():
|
||||
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bloom_560m_tokenizer():
|
||||
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
|
||||
|
|
|
@ -9,13 +9,13 @@ from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
|
|||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
req_1.stopping_parameters.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return BloomCausalLMBatch.from_pb(
|
||||
|
@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
|
|||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert len(batch.input_ids) == default_pb_batch.size
|
||||
assert len(batch.input_ids[0]) == 8
|
||||
assert batch.input_ids[0][-1] == 10264
|
||||
assert torch.all(batch.input_ids[0][:-1] == 3)
|
||||
|
||||
|
@ -85,6 +84,7 @@ def test_causal_lm_batch_type(default_bloom):
|
|||
|
||||
|
||||
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||
sequence_length = len(default_bloom_batch.all_input_ids[0])
|
||||
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
|
||||
|
||||
assert generated_texts == []
|
||||
|
@ -92,7 +92,11 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
|||
assert not next_batch.keys_head_dim_last
|
||||
|
||||
assert len(next_batch.all_input_ids) == next_batch.size
|
||||
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
|
||||
assert (
|
||||
len(next_batch.all_input_ids[0])
|
||||
== len(next_batch.attention_mask[0])
|
||||
== sequence_length + 1
|
||||
)
|
||||
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
|
||||
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
|
||||
|
||||
|
@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
|||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
|
||||
assert all(
|
||||
[p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
||||
|
|
|
@ -8,13 +8,13 @@ from text_generation.models.causal_lm import CausalLM, CausalLMBatch
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
|
|||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
req_1.stopping_parameters.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
|
||||
|
@ -51,7 +51,6 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
|||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert len(batch.input_ids) == default_pb_batch.size
|
||||
assert len(batch.input_ids[0]) == 8
|
||||
assert batch.input_ids[0][-1] == 14402
|
||||
assert torch.all(batch.input_ids[0][:-1] == 50256)
|
||||
|
||||
|
@ -80,6 +79,7 @@ def test_causal_lm_batch_type(default_causal_lm):
|
|||
|
||||
|
||||
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(
|
||||
default_causal_lm_batch
|
||||
)
|
||||
|
@ -88,8 +88,12 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
|||
assert isinstance(next_batch, CausalLMBatch)
|
||||
|
||||
assert len(next_batch.all_input_ids) == next_batch.size
|
||||
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
|
||||
assert next_batch.all_input_ids[0][-1] == 6208
|
||||
assert (
|
||||
len(next_batch.all_input_ids[0])
|
||||
== len(next_batch.attention_mask[0])
|
||||
== sequence_length + 1
|
||||
)
|
||||
assert next_batch.all_input_ids[0][-1] == 13
|
||||
assert next_batch.all_input_ids[0][-2] == 14402
|
||||
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
|
||||
|
||||
|
@ -97,14 +101,18 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
|||
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
|
||||
|
||||
assert next_batch.input_ids.shape == (next_batch.size, 1)
|
||||
assert next_batch.input_ids[0, 0] == 6208
|
||||
assert next_batch.input_ids[0, 0] == 13
|
||||
|
||||
assert next_batch.input_lengths == [2]
|
||||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
|
||||
assert all(
|
||||
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion(
|
||||
|
@ -119,10 +127,7 @@ def test_causal_lm_generate_token_completion(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
|
||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
|
@ -145,7 +150,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "Test Test Test Test Test Test"
|
||||
assert generated_texts[0].output == "Test.java:784)"
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
||||
)
|
||||
|
@ -166,10 +171,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
||||
)
|
||||
|
@ -200,7 +202,8 @@ def test_batch_concatenate(
|
|||
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
|
||||
|
||||
assert next_batch.batch_id == 0
|
||||
assert torch.all(next_batch.input_ids == 6208)
|
||||
assert next_batch.input_ids[0, 0] == 12355
|
||||
assert torch.all(next_batch.input_ids[1:] == 13)
|
||||
|
||||
assert next_batch.input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_sequence_length == 3
|
||||
|
@ -239,7 +242,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "Test Test Test Test Test Test"
|
||||
assert generated_texts[0].output == "Test.java:784)"
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
||||
)
|
||||
|
@ -260,10 +263,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
|
||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
|
@ -283,10 +283,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].output == "Test.java:784) at net.minecraft."
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
||||
)
|
||||
|
|
|
@ -8,13 +8,13 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=2,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
||||
|
||||
|
@ -35,7 +35,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
|
|||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
req_1.stopping_parameters.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
|
||||
|
@ -48,11 +48,12 @@ def default_seq2seq_lm():
|
|||
|
||||
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
|
||||
batch = default_seq2seq_lm_batch
|
||||
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
||||
|
||||
assert batch.batch_id == default_pb_batch.id
|
||||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert batch.input_ids.shape == (default_pb_batch.size, 8)
|
||||
assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
|
||||
assert batch.input_ids[0][-2] == 4268
|
||||
assert batch.input_ids[0][-1] == 1
|
||||
assert torch.all(batch.input_ids[0][:-2] == 0)
|
||||
|
@ -86,6 +87,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
|
|||
|
||||
|
||||
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
||||
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(
|
||||
default_seq2seq_lm_batch
|
||||
)
|
||||
|
@ -108,7 +110,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
|||
assert next_batch.decoder_input_ids[0, 0] == 0
|
||||
assert next_batch.decoder_input_ids[0, 1] == 259
|
||||
assert next_batch.decoder_attention_mask is None
|
||||
assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
|
||||
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
|
||||
|
||||
assert next_batch.decoder_input_lengths == [2]
|
||||
assert next_batch.max_decoder_input_length == 2
|
||||
|
@ -121,10 +123,16 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
|||
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
|
||||
[
|
||||
p[2].shape == (next_batch.size, 6, sequence_length, 64)
|
||||
for p in next_batch.past_key_values
|
||||
]
|
||||
)
|
||||
assert all(
|
||||
[p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
|
||||
[
|
||||
p[3].shape == (next_batch.size, 6, sequence_length, 64)
|
||||
for p in next_batch.past_key_values
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,10 +4,55 @@ from text_generation.utils import (
|
|||
weight_hub_files,
|
||||
download_weights,
|
||||
weight_files,
|
||||
StopSequenceCriteria,
|
||||
StoppingCriteria,
|
||||
LocalEntryNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def test_stop_sequence_criteria():
|
||||
criteria = StopSequenceCriteria([1, 2, 3])
|
||||
|
||||
assert not criteria(1)
|
||||
assert criteria.current_token_idx == 1
|
||||
assert not criteria(2)
|
||||
assert criteria.current_token_idx == 2
|
||||
assert criteria(3)
|
||||
assert criteria.current_token_idx == 3
|
||||
|
||||
|
||||
def test_stop_sequence_criteria_reset():
|
||||
criteria = StopSequenceCriteria([1, 2, 3])
|
||||
|
||||
assert not criteria(1)
|
||||
assert criteria.current_token_idx == 1
|
||||
assert not criteria(2)
|
||||
assert criteria.current_token_idx == 2
|
||||
assert not criteria(4)
|
||||
assert criteria.current_token_idx == 0
|
||||
|
||||
|
||||
def test_stop_sequence_criteria_empty():
|
||||
with pytest.raises(ValueError):
|
||||
StopSequenceCriteria([])
|
||||
|
||||
|
||||
def test_stopping_criteria():
|
||||
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
|
||||
assert criteria([1]) == (False, None)
|
||||
assert criteria([1, 2]) == (False, None)
|
||||
assert criteria([1, 2, 3]) == (True, "stop_sequence")
|
||||
|
||||
|
||||
def test_stopping_criteria_max():
|
||||
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
|
||||
assert criteria([1]) == (False, None)
|
||||
assert criteria([1, 1]) == (False, None)
|
||||
assert criteria([1, 1, 1]) == (False, None)
|
||||
assert criteria([1, 1, 1, 1]) == (False, None)
|
||||
assert criteria([1, 1, 1, 1, 1]) == (True, "length")
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||
assert filenames == ["model.safetensors"]
|
||||
|
|
|
@ -57,23 +57,17 @@ class CausalLMBatch:
|
|||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
top_k=r.parameters.top_k,
|
||||
top_p=r.parameters.top_p,
|
||||
do_sample=r.parameters.do_sample,
|
||||
)
|
||||
)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria(
|
||||
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
|
||||
)
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
||||
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
).to(device)
|
||||
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
||||
|
||||
|
@ -123,8 +117,8 @@ class CausalLMBatch:
|
|||
end_index = start_index + batch.size
|
||||
|
||||
# We only concatenate batches that did at least one step
|
||||
if batch.input_ids.shape[1] > 1:
|
||||
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
|
||||
if batch.past_key_values is None:
|
||||
raise ValueError("only concatenate prefilled batches")
|
||||
|
||||
# Create empty tensor
|
||||
# input_ids is always of shape [batch_size, 1]
|
||||
|
@ -331,14 +325,17 @@ class CausalLM(Model):
|
|||
all_tokens = torch.cat([all_tokens, next_token])
|
||||
|
||||
# Evaluate stopping criteria
|
||||
if stopping_criteria(all_tokens):
|
||||
stop, reason = stopping_criteria(all_tokens)
|
||||
if stop:
|
||||
# Decode all tokens
|
||||
output = self.tokenizer.decode(
|
||||
all_tokens.squeeze(-1), skip_special_tokens=True
|
||||
)
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(
|
||||
GeneratedText(request, output, stopping_criteria.current_tokens)
|
||||
GeneratedText(
|
||||
request, output, stopping_criteria.current_tokens, reason
|
||||
)
|
||||
)
|
||||
# add to the next batch
|
||||
else:
|
||||
|
|
|
@ -94,18 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
top_k=r.parameters.top_k,
|
||||
top_p=r.parameters.top_p,
|
||||
do_sample=r.parameters.do_sample,
|
||||
)
|
||||
)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria(
|
||||
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
|
||||
)
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
|
@ -207,11 +198,7 @@ class GalacticaSharded(Galactica):
|
|||
continue
|
||||
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
try:
|
||||
module = model.get_submodule(module_name)
|
||||
except Exception as e:
|
||||
print(type(model), name, module_name, param_name)
|
||||
raise e
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
|
|
@ -68,24 +68,18 @@ class Seq2SeqLMBatch:
|
|||
# Decoder sequence only contains the bos_token
|
||||
decoder_input_ids.append(tokenizer.bos_token_id)
|
||||
decoder_input_lengths.append(1)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
top_k=r.parameters.top_k,
|
||||
top_p=r.parameters.top_p,
|
||||
do_sample=r.parameters.do_sample,
|
||||
)
|
||||
)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria(
|
||||
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
|
||||
)
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
||||
# Tokenize batch
|
||||
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
).to(device)
|
||||
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
|
||||
|
@ -431,12 +425,15 @@ class Seq2SeqLM(Model):
|
|||
decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
|
||||
|
||||
# Evaluate stopping criteria
|
||||
if stopping_criteria(decoder_tokens):
|
||||
stop, reason = stopping_criteria(decoder_tokens)
|
||||
if stop:
|
||||
# Decode tokens
|
||||
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(
|
||||
GeneratedText(request, output, stopping_criteria.current_tokens)
|
||||
GeneratedText(
|
||||
request, output, stopping_criteria.current_tokens, reason
|
||||
)
|
||||
)
|
||||
# add to the next batch
|
||||
else:
|
||||
|
|
|
@ -32,8 +32,12 @@ class GeneratedText:
|
|||
request: generate_pb2.Request
|
||||
output: str
|
||||
tokens: int
|
||||
reason: str
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
return generate_pb2.GeneratedText(
|
||||
request=self.request, output=self.output, tokens=self.tokens
|
||||
request=self.request,
|
||||
output=self.output,
|
||||
tokens=self.tokens,
|
||||
finish_reason=self.reason,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import concurrent
|
||||
import os
|
||||
import signal
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
@ -11,6 +10,8 @@ from functools import partial
|
|||
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional, Tuple
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
|
@ -18,6 +19,8 @@ from transformers.generation.logits_process import (
|
|||
TopKLogitsWarper,
|
||||
)
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __call__(self, logits):
|
||||
|
@ -56,20 +59,72 @@ class NextTokenChooser:
|
|||
next_ids = self.choice(scores)
|
||||
return next_ids.unsqueeze(-1)
|
||||
|
||||
@classmethod
|
||||
def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
|
||||
return NextTokenChooser(
|
||||
temperature=pb.temperature,
|
||||
top_k=pb.top_k,
|
||||
top_p=pb.top_p,
|
||||
do_sample=pb.do_sample,
|
||||
)
|
||||
|
||||
|
||||
class StopSequenceCriteria:
|
||||
def __init__(self, tokens: List[int]):
|
||||
if not tokens:
|
||||
raise ValueError("tokens cannot be empty")
|
||||
|
||||
self.tokens = tokens
|
||||
self.current_token_idx = 0
|
||||
|
||||
def __call__(self, last_token: int) -> bool:
|
||||
if last_token == self.tokens[self.current_token_idx]:
|
||||
# Increase idx to go to next token
|
||||
self.current_token_idx += 1
|
||||
else:
|
||||
# Reset to first token of the stopping sequence
|
||||
self.current_token_idx = 0
|
||||
|
||||
if self.current_token_idx == len(self.tokens):
|
||||
# We matched the entire sequence without resetting
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class StoppingCriteria:
|
||||
def __init__(self, eos_token_id, max_new_tokens=20):
|
||||
self.eos_token_id = eos_token_id
|
||||
def __init__(
|
||||
self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
|
||||
):
|
||||
self.stop_sequence_criterias = stop_sequence_criterias
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.current_tokens = 0
|
||||
|
||||
def __call__(self, all_ids):
|
||||
def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
|
||||
self.current_tokens += 1
|
||||
if self.current_tokens >= self.max_new_tokens:
|
||||
return True
|
||||
if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
|
||||
return True
|
||||
return False
|
||||
return True, "length"
|
||||
|
||||
last_token = all_ids[-1]
|
||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
||||
if stop_sequence_criteria(last_token):
|
||||
return True, "stop_sequence"
|
||||
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
|
||||
) -> "StoppingCriteria":
|
||||
stop_sequence_criterias = []
|
||||
for stop_sequence in pb.stop_sequences:
|
||||
tokens = tokenizer(
|
||||
stop_sequence, padding=False, return_attention_mask=False
|
||||
).input_ids
|
||||
if tokens:
|
||||
stop_sequence_criterias.append(StopSequenceCriteria(tokens))
|
||||
stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
|
||||
|
||||
return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
|
||||
|
||||
|
||||
def initialize_torch_distributed():
|
||||
|
|
Loading…
Reference in New Issue