feat: Support stop sequences (#7)

parent 042180d88f
commit 718096f695
@@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
+- Logits warpers (temperature scaling, topk ...)
+- Stop sequences
 
 ## Officially supported models
 
@@ -28,12 +28,23 @@ message ClearCacheRequest {}
 message ClearCacheResponse {}
 
 message LogitsWarperParameters {
+    /// exponential scaling output probability distribution
     float temperature = 1;
+    /// restricting to the k highest probability elements
     uint32 top_k = 2;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
     float top_p = 3;
+    /// apply sampling on the logits
     bool do_sample = 4;
 }
 
+message StoppingCriteriaParameters {
+    /// Maximum number of generated tokens
+    uint32 max_new_tokens = 1;
+    /// Optional stopping sequences
+    repeated string stop_sequences = 2;
+}
+
 message Request {
     /// Request ID
     uint64 id = 1;
@@ -43,8 +54,8 @@ message Request {
     uint32 input_length = 3;
     /// Logits Warper Parameters
     LogitsWarperParameters parameters = 4;
-    /// Stopping criteria
-    uint32 max_new_tokens = 5;
+    /// Stopping Criteria Parameters
+    StoppingCriteriaParameters stopping_parameters = 5;
 }
 
 message Batch {
@@ -63,6 +74,8 @@ message GeneratedText {
     string output = 2;
     /// Number of generated tokens
     uint32 tokens = 3;
+    /// Finish reason
+    string finish_reason = 4;
 }
 
 message GenerateRequest {
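For illustration, a client could build the new request shape with the generated Python stubs. This is a minimal sketch only, assuming the stubs are exposed as `text_generation.pb.generate_pb2` (the import path used by the Python server later in this commit); the field names come from the messages above, the values are arbitrary:

from text_generation.pb import generate_pb2

# Logits warper parameters (unchanged message, now documented field by field)
parameters = generate_pb2.LogitsWarperParameters(
    temperature=1.0, top_k=0, top_p=1.0, do_sample=False
)

# New message: carries the token budget and the optional stop sequences
stopping_parameters = generate_pb2.StoppingCriteriaParameters(
    max_new_tokens=20, stop_sequences=["\n\n"]
)

request = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_length=1,
    parameters=parameters,
    stopping_parameters=stopping_parameters,
)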
@@ -6,7 +6,9 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
-pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
+pub use pb::generate::v1::{
+    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
 use tonic::transport;
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
@@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
+use text_generation_client::{
+    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
 
@@ -72,7 +74,9 @@ impl State {
                 parameters: Some(LogitsWarperParameters::from(
                     entry.request.parameters.clone(),
                 )),
-                max_new_tokens: entry.request.parameters.max_new_tokens,
+                stopping_parameters: Some(StoppingCriteriaParameters::from(
+                    entry.request.parameters.clone(),
+                )),
             });
 
             ids.push(*id);
@@ -168,3 +172,12 @@ impl From<GenerateParameters> for LogitsWarperParameters {
         }
     }
 }
+
+impl From<GenerateParameters> for StoppingCriteriaParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            stop_sequences: parameters.stop,
+            max_new_tokens: parameters.max_new_tokens,
+        }
+    }
+}
@@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
+    pub stop: Vec<String>,
 }
 
 fn default_temperature() -> f32 {
@@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters {
         top_p: default_top_p(),
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
+        stop: vec![],
     }
 }
 
@@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }
 
 #[derive(Serialize)]
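The router's public `GenerateParameters` gains a `stop` field and the serialized `GeneratedText` gains a `finish_reason`. As a hedged sketch, a request payload using the new field might look like the following; the parameter names are taken from this diff, while the surrounding request envelope and the endpoint are assumptions not shown in this commit:

import json

payload = {
    "inputs": "def fibonacci(n):",
    "parameters": {
        "temperature": 1.0,
        "top_k": 0,
        "top_p": 1.0,
        "do_sample": False,
        "max_new_tokens": 60,
        "stop": ["\n\n"],  # the validation change below rejects more than 4 stop sequences
    },
}
print(json.dumps(payload, indent=2))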
@@ -53,6 +53,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 top_p: 1.0,
                 do_sample: false,
                 max_new_tokens: 1,
+                stop: vec![],
             },
         },
     )
@@ -88,11 +89,8 @@ async fn generate(
     })?;
 
     // Validate request
-    let (input_length, validated_request) = state
-        .validation
-        .validate(req.0)
-        .await
-        .map_err(|err| {
+    let (input_length, validated_request) =
+        state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());
             err
         })?;
@@ -148,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
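With `finish_reason` threaded from the shard's `GeneratedText` through `InferResponse` into the serialized response, the JSON returned by the generate handler gains one field. Roughly, with illustrative values (only the field names and the two reason strings appear in this commit):

response = [
    {
        "generated_text": "Test.java:784) at net.minecraft.",
        "finish_reason": "length",  # or "stop_sequence" when a stop sequence matched
    }
]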
@@ -121,6 +121,14 @@ fn validation_worker(
                     .unwrap_or(());
                 continue;
             }
+            if request.parameters.stop.len() > 4 {
+                response_tx
+                    .send(Err(ValidationError::StopSequence(
+                        request.parameters.stop.len(),
+                    )))
+                    .unwrap_or(());
+                continue;
+            }
 
             // Get the number of tokens in the input
             match tokenizer.encode(request.inputs.clone(), false) {
@@ -163,6 +171,8 @@ pub enum ValidationError {
     MaxNewTokens,
     #[error("inputs must have less than {1} tokens. Given: {0}")]
     InputLength(usize, usize),
+    #[error("stop supports up to 4 stop sequences. Given: {0}")]
+    StopSequence(usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
 }
@@ -15,6 +15,11 @@ def default_pb_parameters():
     )
 
 
+@pytest.fixture
+def default_pb_stop_parameters():
+    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+
+
 @pytest.fixture(scope="session")
 def bloom_560m_tokenizer():
     return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@@ -9,13 +9,13 @@ from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
 
 
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=1,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )
 
 
@@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5
 
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
@@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.requests == default_pb_batch.requests
 
     assert len(batch.input_ids) == default_pb_batch.size
-    assert len(batch.input_ids[0]) == 8
     assert batch.input_ids[0][-1] == 10264
     assert torch.all(batch.input_ids[0][:-1] == 3)
 
@@ -85,6 +84,7 @@ def test_causal_lm_batch_type(default_bloom):
 
 
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
+    sequence_length = len(default_bloom_batch.all_input_ids[0])
     generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
 
     assert generated_texts == []
@@ -92,7 +92,11 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert not next_batch.keys_head_dim_last
 
     assert len(next_batch.all_input_ids) == next_batch.size
-    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
+    assert (
+        len(next_batch.all_input_ids[0])
+        == len(next_batch.attention_mask[0])
+        == sequence_length + 1
+    )
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
     assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
 
@@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.max_sequence_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
+    )
 
 
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
@@ -8,13 +8,13 @@ from text_generation.models.causal_lm import CausalLM, CausalLMBatch
 
 
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=1,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )
 
 
@@ -33,7 +33,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5
 
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@@ -51,7 +51,6 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.requests == default_pb_batch.requests
 
     assert len(batch.input_ids) == default_pb_batch.size
-    assert len(batch.input_ids[0]) == 8
     assert batch.input_ids[0][-1] == 14402
     assert torch.all(batch.input_ids[0][:-1] == 50256)
 
@@ -80,6 +79,7 @@ def test_causal_lm_batch_type(default_causal_lm):
 
 
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
+    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
     generated_texts, next_batch = default_causal_lm.generate_token(
         default_causal_lm_batch
     )
@@ -88,8 +88,12 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert isinstance(next_batch, CausalLMBatch)
 
     assert len(next_batch.all_input_ids) == next_batch.size
-    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
-    assert next_batch.all_input_ids[0][-1] == 6208
+    assert (
+        len(next_batch.all_input_ids[0])
+        == len(next_batch.attention_mask[0])
+        == sequence_length + 1
+    )
+    assert next_batch.all_input_ids[0][-1] == 13
     assert next_batch.all_input_ids[0][-2] == 14402
     assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
 
@@ -97,14 +101,18 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][:-2] == 0)
 
     assert next_batch.input_ids.shape == (next_batch.size, 1)
-    assert next_batch.input_ids[0, 0] == 6208
+    assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
     assert next_batch.max_sequence_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
+    )
 
 
 def test_causal_lm_generate_token_completion(
@@ -119,10 +127,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
         generated_texts[0].tokens
@@ -145,7 +150,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert generated_texts[0].output == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
@@ -166,10 +171,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
@@ -200,7 +202,8 @@ def test_batch_concatenate(
     assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
 
     assert next_batch.batch_id == 0
-    assert torch.all(next_batch.input_ids == 6208)
+    assert next_batch.input_ids[0, 0] == 12355
+    assert torch.all(next_batch.input_ids[1:] == 13)
 
     assert next_batch.input_lengths == [3, 2, 2]
     assert next_batch.max_sequence_length == 3
@@ -239,7 +242,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert generated_texts[0].output == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
@@ -260,10 +263,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
         generated_texts[0].tokens
@@ -283,10 +283,7 @@ def test_batch_concatenate(
    assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
@@ -8,13 +8,13 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 
 
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=2,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )
 
 
@@ -35,7 +35,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5
 
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
@@ -48,11 +48,12 @@ def default_seq2seq_lm():
 
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     batch = default_seq2seq_lm_batch
+    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
 
     assert batch.batch_id == default_pb_batch.id
     assert batch.requests == default_pb_batch.requests
 
-    assert batch.input_ids.shape == (default_pb_batch.size, 8)
+    assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
     assert batch.input_ids[0][-2] == 4268
     assert batch.input_ids[0][-1] == 1
     assert torch.all(batch.input_ids[0][:-2] == 0)
@@ -86,6 +87,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 
 
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
+    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
     generated_texts, next_batch = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )
@@ -108,7 +110,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert next_batch.decoder_input_ids[0, 0] == 0
     assert next_batch.decoder_input_ids[0, 1] == 259
     assert next_batch.decoder_attention_mask is None
-    assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
+    assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
 
     assert next_batch.decoder_input_lengths == [2]
     assert next_batch.max_decoder_input_length == 2
@@ -121,10 +123,16 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
         [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+        [
+            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            for p in next_batch.past_key_values
+        ]
     )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+        [
+            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            for p in next_batch.past_key_values
+        ]
     )
 
 
@@ -4,10 +4,55 @@ from text_generation.utils import (
     weight_hub_files,
     download_weights,
     weight_files,
+    StopSequenceCriteria,
+    StoppingCriteria,
     LocalEntryNotFoundError,
 )
 
 
+def test_stop_sequence_criteria():
+    criteria = StopSequenceCriteria([1, 2, 3])
+
+    assert not criteria(1)
+    assert criteria.current_token_idx == 1
+    assert not criteria(2)
+    assert criteria.current_token_idx == 2
+    assert criteria(3)
+    assert criteria.current_token_idx == 3
+
+
+def test_stop_sequence_criteria_reset():
+    criteria = StopSequenceCriteria([1, 2, 3])
+
+    assert not criteria(1)
+    assert criteria.current_token_idx == 1
+    assert not criteria(2)
+    assert criteria.current_token_idx == 2
+    assert not criteria(4)
+    assert criteria.current_token_idx == 0
+
+
+def test_stop_sequence_criteria_empty():
+    with pytest.raises(ValueError):
+        StopSequenceCriteria([])
+
+
+def test_stopping_criteria():
+    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
+    assert criteria([1]) == (False, None)
+    assert criteria([1, 2]) == (False, None)
+    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+
+
+def test_stopping_criteria_max():
+    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
+    assert criteria([1]) == (False, None)
+    assert criteria([1, 1]) == (False, None)
+    assert criteria([1, 1, 1]) == (False, None)
+    assert criteria([1, 1, 1, 1]) == (False, None)
+    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+
+
 def test_weight_hub_files():
     filenames = weight_hub_files("bigscience/bloom-560m")
     assert filenames == ["model.safetensors"]
@@ -57,23 +57,17 @@ class CausalLMBatch:
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
@@ -123,8 +117,8 @@ class CausalLMBatch:
             end_index = start_index + batch.size
 
             # We only concatenate batches that did at least one step
-            if batch.input_ids.shape[1] > 1:
-                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+            if batch.past_key_values is None:
+                raise ValueError("only concatenate prefilled batches")
 
             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]
@@ -331,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])
 
             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
@@ -94,18 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
         tokenized_inputs = tokenizer(
@@ -207,11 +198,7 @@ class GalacticaSharded(Galactica):
                     continue
 
                 module_name, param_name = name.rsplit(".", 1)
-                try:
-                    module = model.get_submodule(module_name)
-                except Exception as e:
-                    print(type(model), name, module_name, param_name)
-                    raise e
+                module = model.get_submodule(module_name)
                 current_tensor = parameters[name]
 
                 slice_ = f.get_slice(name)
@@ -68,24 +68,18 @@ class Seq2SeqLMBatch:
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
         # Tokenize batch
         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@@ -431,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
 
             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
|
@ -32,8 +32,12 @@ class GeneratedText:
|
||||||
request: generate_pb2.Request
|
request: generate_pb2.Request
|
||||||
output: str
|
output: str
|
||||||
tokens: int
|
tokens: int
|
||||||
|
reason: str
|
||||||
|
|
||||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||||
return generate_pb2.GeneratedText(
|
return generate_pb2.GeneratedText(
|
||||||
request=self.request, output=self.output, tokens=self.tokens
|
request=self.request,
|
||||||
|
output=self.output,
|
||||||
|
tokens=self.tokens,
|
||||||
|
finish_reason=self.reason,
|
||||||
)
|
)
|
||||||
|
|
|
@@ -1,6 +1,5 @@
 import concurrent
 import os
-import signal
 import torch
 import torch.distributed
 
@@ -11,6 +10,8 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
+from typing import List, Optional, Tuple
+from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -18,6 +19,8 @@ from transformers.generation.logits_process import (
     TopKLogitsWarper,
 )
 
+from text_generation.pb import generate_pb2
+
 
 class Sampling:
     def __call__(self, logits):
@@ -56,20 +59,72 @@ class NextTokenChooser:
         next_ids = self.choice(scores)
         return next_ids.unsqueeze(-1)
 
+    @classmethod
+    def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
+        return NextTokenChooser(
+            temperature=pb.temperature,
+            top_k=pb.top_k,
+            top_p=pb.top_p,
+            do_sample=pb.do_sample,
+        )
+
+
+class StopSequenceCriteria:
+    def __init__(self, tokens: List[int]):
+        if not tokens:
+            raise ValueError("tokens cannot be empty")
+
+        self.tokens = tokens
+        self.current_token_idx = 0
+
+    def __call__(self, last_token: int) -> bool:
+        if last_token == self.tokens[self.current_token_idx]:
+            # Increase idx to go to next token
+            self.current_token_idx += 1
+        else:
+            # Reset to first token of the stopping sequence
+            self.current_token_idx = 0
+
+        if self.current_token_idx == len(self.tokens):
+            # We matched the entire sequence without resetting
+            return True
+        return False
+
 
 class StoppingCriteria:
-    def __init__(self, eos_token_id, max_new_tokens=20):
-        self.eos_token_id = eos_token_id
+    def __init__(
+        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+    ):
+        self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
 
-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
-        if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
-            return True
-        return False
+            return True, "length"
+
+        last_token = all_ids[-1]
+        for stop_sequence_criteria in self.stop_sequence_criterias:
+            if stop_sequence_criteria(last_token):
+                return True, "stop_sequence"
+
+        return False, None
+
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+    ) -> "StoppingCriteria":
+        stop_sequence_criterias = []
+        for stop_sequence in pb.stop_sequences:
+            tokens = tokenizer(
+                stop_sequence, padding=False, return_attention_mask=False
+            ).input_ids
+            if tokens:
+                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
+        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
+
+        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
 
 
 def initialize_torch_distributed():
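Putting the pieces together, the new stopping logic can be exercised on its own. This is a minimal usage sketch, assuming a BLOOM tokenizer as in the test fixtures; the class and message names come from this commit, the stop sequence value is arbitrary:

from transformers import AutoTokenizer

from text_generation.pb import generate_pb2
from text_generation.utils import StoppingCriteria

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
pb = generate_pb2.StoppingCriteriaParameters(stop_sequences=["\n\n"], max_new_tokens=20)
criteria = StoppingCriteria.from_pb(pb, tokenizer)

# Pretend the model just produced the tokens of the stop sequence itself
all_ids = []
for token_id in tokenizer("\n\n", padding=False, return_attention_mask=False).input_ids:
    all_ids.append(token_id)
    stop, reason = criteria(all_ids)  # reason is None, "stop_sequence" or "length"

print(stop, reason)  # with max_new_tokens=20 this should stop with reason "stop_sequence"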