From 718096f6952147681559ba0a3962040f8655af1f Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 12 Dec 2022 18:25:22 +0100 Subject: [PATCH] feat: Support stop sequences (#7) --- README.md | 2 + proto/generate.proto | 17 ++++- router/client/src/lib.rs | 4 +- router/src/batcher.rs | 2 + router/src/db.rs | 17 ++++- router/src/lib.rs | 3 + router/src/server.rs | 9 ++- router/src/validation.rs | 10 +++ server/tests/conftest.py | 5 ++ server/tests/models/test_bloom.py | 22 +++++-- server/tests/models/test_causal_lm.py | 53 ++++++++------- server/tests/models/test_seq2seq_lm.py | 22 +++++-- server/tests/test_utils.py | 45 +++++++++++++ server/text_generation/models/causal_lm.py | 29 ++++----- server/text_generation/models/galactica.py | 19 +----- server/text_generation/models/seq2seq_lm.py | 25 ++++---- server/text_generation/models/types.py | 6 +- server/text_generation/utils.py | 71 ++++++++++++++++++--- 18 files changed, 254 insertions(+), 107 deletions(-) diff --git a/README.md b/README.md index bb58e281..92879707 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets. - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - [Safetensors](https://github.com/huggingface/safetensors) weight loading - 45ms per token generation for BLOOM with 8xA100 80GB +- Logits warpers (temperature scaling, topk ...) +- Stop sequences ## Officially supported models diff --git a/proto/generate.proto b/proto/generate.proto index 14f6f66b..0c67de03 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -28,12 +28,23 @@ message ClearCacheRequest {} message ClearCacheResponse {} message LogitsWarperParameters { + /// exponential scaling output probability distribution float temperature = 1; + /// restricting to the k highest probability elements uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off float top_p = 3; + /// apply sampling on the logits bool do_sample = 4; } +message StoppingCriteriaParameters { + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; +} + message Request { /// Request ID uint64 id = 1; @@ -43,8 +54,8 @@ message Request { uint32 input_length = 3; /// Logits Warper Parameters LogitsWarperParameters parameters = 4; - /// Stopping criteria - uint32 max_new_tokens = 5; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; } message Batch { @@ -63,6 +74,8 @@ message GeneratedText { string output = 2; /// Number of generated tokens uint32 tokens = 3; + /// Finish reason + string finish_reason = 4; } message GenerateRequest { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 0f1f96bc..ae337dd6 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -6,7 +6,9 @@ mod pb; mod sharded_client; pub use client::Client; -pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request}; +pub use pb::generate::v1::{ + Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters, +}; pub use sharded_client::ShardedClient; use thiserror::Error; use tonic::transport; diff --git a/router/src/batcher.rs b/router/src/batcher.rs index 0c85a406..a72a6e44 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -190,6 +190,7 @@ fn send_generated(finished: Vec, db: &Db) { let response = InferResponse { output: output.output, tokens: output.tokens, + finish_reason: output.finish_reason, queued: entry.time, start: entry.batch_time.unwrap(), // unwrap is always valid end: Instant::now(), @@ -203,6 +204,7 @@ fn send_generated(finished: Vec, db: &Db) { pub(crate) struct InferResponse { pub(crate) output: String, pub(crate) tokens: u32, + pub(crate) finish_reason: String, pub(crate) queued: Instant, pub(crate) start: Instant, pub(crate) end: Instant, diff --git a/router/src/db.rs b/router/src/db.rs index 0701206b..24fb7a09 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; -use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request}; +use text_generation_client::{ + Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters, +}; use tokio::sync::oneshot::Sender; use tokio::time::Instant; @@ -72,7 +74,9 @@ impl State { parameters: Some(LogitsWarperParameters::from( entry.request.parameters.clone(), )), - max_new_tokens: entry.request.parameters.max_new_tokens, + stopping_parameters: Some(StoppingCriteriaParameters::from( + entry.request.parameters.clone(), + )), }); ids.push(*id); @@ -168,3 +172,12 @@ impl From for LogitsWarperParameters { } } } + +impl From for StoppingCriteriaParameters { + fn from(parameters: GenerateParameters) -> Self { + Self { + stop_sequences: parameters.stop, + max_new_tokens: parameters.max_new_tokens, + } + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 8646ad59..b6c694ee 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters { pub do_sample: bool, #[serde(default = "default_max_new_tokens")] pub max_new_tokens: u32, + pub stop: Vec, } fn default_temperature() -> f32 { @@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters { top_p: default_top_p(), do_sample: default_do_sample(), max_new_tokens: default_max_new_tokens(), + stop: vec![], } } @@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest { #[derive(Serialize)] pub(crate) struct GeneratedText { pub generated_text: String, + pub finish_reason: String, } #[derive(Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index 9f4a75c9..59296269 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -53,6 +53,7 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json 4 { + response_tx + .send(Err(ValidationError::StopSequence( + request.parameters.stop.len(), + ))) + .unwrap_or(()); + continue; + } // Get the number of tokens in the input match tokenizer.encode(request.inputs.clone(), false) { @@ -163,6 +171,8 @@ pub enum ValidationError { MaxNewTokens, #[error("inputs must have less than {1} tokens. Given: {0}")] InputLength(usize, usize), + #[error("stop supports up to 4 stop sequences. Given: {0}")] + StopSequence(usize), #[error("tokenizer error {0}")] Tokenizer(String), } diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 0640d45d..24cdafac 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -15,6 +15,11 @@ def default_pb_parameters(): ) +@pytest.fixture +def default_pb_stop_parameters(): + return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) + + @pytest.fixture(scope="session") def bloom_560m_tokenizer(): return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left") diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 49dabb14..c5dbaa3e 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -9,13 +9,13 @@ from text_generation.models.bloom import BloomCausalLMBatch, BLOOM @pytest.fixture -def default_pb_request(default_pb_parameters): +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", input_length=1, parameters=default_pb_parameters, - max_new_tokens=10, + stopping_parameters=default_pb_stop_parameters, ) @@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer) req_0 = copy(default_pb_request) req_1 = default_pb_request req_1.id = 1 - req_1.max_new_tokens = 5 + req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return BloomCausalLMBatch.from_pb( @@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch): assert batch.requests == default_pb_batch.requests assert len(batch.input_ids) == default_pb_batch.size - assert len(batch.input_ids[0]) == 8 assert batch.input_ids[0][-1] == 10264 assert torch.all(batch.input_ids[0][:-1] == 3) @@ -85,6 +84,7 @@ def test_causal_lm_batch_type(default_bloom): def test_causal_lm_generate_token(default_bloom, default_bloom_batch): + sequence_length = len(default_bloom_batch.all_input_ids[0]) generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch) assert generated_texts == [] @@ -92,7 +92,11 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert not next_batch.keys_head_dim_last assert len(next_batch.all_input_ids) == next_batch.size - assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9 + assert ( + len(next_batch.all_input_ids[0]) + == len(next_batch.attention_mask[0]) + == sequence_length + 1 + ) assert torch.all(next_batch.all_input_ids[0][-2:] == 10264) assert torch.all(next_batch.all_input_ids[0][:-2] == 3) @@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert next_batch.max_sequence_length == next_batch.input_lengths[0] assert next_batch.past_key_values is not None - assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values]) - assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values]) + assert all( + [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values] + ) def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 1bf3e5e6..f38776cc 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -8,13 +8,13 @@ from text_generation.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture -def default_pb_request(default_pb_parameters): +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", input_length=1, parameters=default_pb_parameters, - max_new_tokens=10, + stopping_parameters=default_pb_stop_parameters, ) @@ -33,7 +33,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): req_0 = copy(default_pb_request) req_1 = default_pb_request req_1.id = 1 - req_1.max_new_tokens = 5 + req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu")) @@ -51,7 +51,6 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): assert batch.requests == default_pb_batch.requests assert len(batch.input_ids) == default_pb_batch.size - assert len(batch.input_ids[0]) == 8 assert batch.input_ids[0][-1] == 14402 assert torch.all(batch.input_ids[0][:-1] == 50256) @@ -80,6 +79,7 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): + sequence_length = len(default_causal_lm_batch.all_input_ids[0]) generated_texts, next_batch = default_causal_lm.generate_token( default_causal_lm_batch ) @@ -88,8 +88,12 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): assert isinstance(next_batch, CausalLMBatch) assert len(next_batch.all_input_ids) == next_batch.size - assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9 - assert next_batch.all_input_ids[0][-1] == 6208 + assert ( + len(next_batch.all_input_ids[0]) + == len(next_batch.attention_mask[0]) + == sequence_length + 1 + ) + assert next_batch.all_input_ids[0][-1] == 13 assert next_batch.all_input_ids[0][-2] == 14402 assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) @@ -97,14 +101,18 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): assert torch.all(next_batch.attention_mask[0][:-2] == 0) assert next_batch.input_ids.shape == (next_batch.size, 1) - assert next_batch.input_ids[0, 0] == 6208 + assert next_batch.input_ids[0, 0] == 13 assert next_batch.input_lengths == [2] assert next_batch.max_sequence_length == next_batch.input_lengths[0] assert next_batch.past_key_values is not None - assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values]) - assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values]) + assert all( + [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] + ) def test_causal_lm_generate_token_completion( @@ -119,10 +127,7 @@ def test_causal_lm_generate_token_completion( assert next_batch is None assert len(generated_texts) == 1 - assert ( - generated_texts[0].output - == "Test Test Test Test Test Test Test Test Test Test Test" - ) + assert generated_texts[0].output == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert ( generated_texts[0].tokens @@ -145,7 +150,7 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test Test Test Test Test Test" + assert generated_texts[0].output == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) @@ -166,10 +171,7 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert ( - generated_texts[0].output - == "Test Test Test Test Test Test Test Test Test Test Test" - ) + assert generated_texts[0].output == "Test.java:784) at net.minecraft." assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) @@ -200,7 +202,8 @@ def test_batch_concatenate( assert torch.all(next_batch.attention_mask[1:, :-2] == 0) assert next_batch.batch_id == 0 - assert torch.all(next_batch.input_ids == 6208) + assert next_batch.input_ids[0, 0] == 12355 + assert torch.all(next_batch.input_ids[1:] == 13) assert next_batch.input_lengths == [3, 2, 2] assert next_batch.max_sequence_length == 3 @@ -239,7 +242,7 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test Test Test Test Test Test" + assert generated_texts[0].output == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) @@ -260,10 +263,7 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert ( - generated_texts[0].output - == "Test Test Test Test Test Test Test Test Test Test Test" - ) + assert generated_texts[0].output == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert ( generated_texts[0].tokens @@ -283,10 +283,7 @@ def test_batch_concatenate( assert next_batch is None assert len(generated_texts) == 1 - assert ( - generated_texts[0].output - == "Test Test Test Test Test Test Test Test Test Test Test" - ) + assert generated_texts[0].output == "Test.java:784) at net.minecraft." assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 7e4c7fdd..94ec70d5 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -8,13 +8,13 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch @pytest.fixture -def default_pb_request(default_pb_parameters): +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", input_length=2, parameters=default_pb_parameters, - max_new_tokens=10, + stopping_parameters=default_pb_stop_parameters, ) @@ -35,7 +35,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni req_0 = copy(default_pb_request) req_1 = default_pb_request req_1.id = 1 - req_1.max_new_tokens = 5 + req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu")) @@ -48,11 +48,12 @@ def default_seq2seq_lm(): def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): batch = default_seq2seq_lm_batch + sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) assert batch.batch_id == default_pb_batch.id assert batch.requests == default_pb_batch.requests - assert batch.input_ids.shape == (default_pb_batch.size, 8) + assert batch.input_ids.shape == (default_pb_batch.size, sequence_length) assert batch.input_ids[0][-2] == 4268 assert batch.input_ids[0][-1] == 1 assert torch.all(batch.input_ids[0][:-2] == 0) @@ -86,6 +87,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm): def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): + sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) generated_texts, next_batch = default_seq2seq_lm.generate_token( default_seq2seq_lm_batch ) @@ -108,7 +110,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert next_batch.decoder_input_ids[0, 0] == 0 assert next_batch.decoder_input_ids[0, 1] == 259 assert next_batch.decoder_attention_mask is None - assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512) + assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512) assert next_batch.decoder_input_lengths == [2] assert next_batch.max_decoder_input_length == 2 @@ -121,10 +123,16 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values] ) assert all( - [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values] + [ + p[2].shape == (next_batch.size, 6, sequence_length, 64) + for p in next_batch.past_key_values + ] ) assert all( - [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values] + [ + p[3].shape == (next_batch.size, 6, sequence_length, 64) + for p in next_batch.past_key_values + ] ) diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py index e630ebda..ed6e896a 100644 --- a/server/tests/test_utils.py +++ b/server/tests/test_utils.py @@ -4,10 +4,55 @@ from text_generation.utils import ( weight_hub_files, download_weights, weight_files, + StopSequenceCriteria, + StoppingCriteria, LocalEntryNotFoundError, ) +def test_stop_sequence_criteria(): + criteria = StopSequenceCriteria([1, 2, 3]) + + assert not criteria(1) + assert criteria.current_token_idx == 1 + assert not criteria(2) + assert criteria.current_token_idx == 2 + assert criteria(3) + assert criteria.current_token_idx == 3 + + +def test_stop_sequence_criteria_reset(): + criteria = StopSequenceCriteria([1, 2, 3]) + + assert not criteria(1) + assert criteria.current_token_idx == 1 + assert not criteria(2) + assert criteria.current_token_idx == 2 + assert not criteria(4) + assert criteria.current_token_idx == 0 + + +def test_stop_sequence_criteria_empty(): + with pytest.raises(ValueError): + StopSequenceCriteria([]) + + +def test_stopping_criteria(): + criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5) + assert criteria([1]) == (False, None) + assert criteria([1, 2]) == (False, None) + assert criteria([1, 2, 3]) == (True, "stop_sequence") + + +def test_stopping_criteria_max(): + criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5) + assert criteria([1]) == (False, None) + assert criteria([1, 1]) == (False, None) + assert criteria([1, 1, 1]) == (False, None) + assert criteria([1, 1, 1, 1]) == (False, None) + assert criteria([1, 1, 1, 1, 1]) == (True, "length") + + def test_weight_hub_files(): filenames = weight_hub_files("bigscience/bloom-560m") assert filenames == ["model.safetensors"] diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 1cd999f0..72095858 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -57,23 +57,17 @@ class CausalLMBatch: for r in pb.requests: inputs.append(r.inputs) input_lengths.append(r.input_length) - next_token_choosers.append( - NextTokenChooser( - temperature=r.parameters.temperature, - top_k=r.parameters.top_k, - top_p=r.parameters.top_p, - do_sample=r.parameters.do_sample, - ) - ) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters)) stopping_criterias.append( - StoppingCriteria( - eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens - ) + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) pad_to_multiple_of = 8 if "gpu" in str(device) else None tokenized_inputs = tokenizer( - inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of + inputs, + return_tensors="pt", + padding=True, + pad_to_multiple_of=pad_to_multiple_of, ).to(device) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) @@ -123,8 +117,8 @@ class CausalLMBatch: end_index = start_index + batch.size # We only concatenate batches that did at least one step - if batch.input_ids.shape[1] > 1: - raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") # Create empty tensor # input_ids is always of shape [batch_size, 1] @@ -331,14 +325,17 @@ class CausalLM(Model): all_tokens = torch.cat([all_tokens, next_token]) # Evaluate stopping criteria - if stopping_criteria(all_tokens): + stop, reason = stopping_criteria(all_tokens) + if stop: # Decode all tokens output = self.tokenizer.decode( all_tokens.squeeze(-1), skip_special_tokens=True ) # Add to the list of finished generations with the original request generated_texts.append( - GeneratedText(request, output, stopping_criteria.current_tokens) + GeneratedText( + request, output, stopping_criteria.current_tokens, reason + ) ) # add to the next batch else: diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 8aec1bc7..680ea43e 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -94,18 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch): # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) input_lengths.append(r.input_length) - next_token_choosers.append( - NextTokenChooser( - temperature=r.parameters.temperature, - top_k=r.parameters.top_k, - top_p=r.parameters.top_p, - do_sample=r.parameters.do_sample, - ) - ) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters)) stopping_criterias.append( - StoppingCriteria( - eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens - ) + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) tokenized_inputs = tokenizer( @@ -207,11 +198,7 @@ class GalacticaSharded(Galactica): continue module_name, param_name = name.rsplit(".", 1) - try: - module = model.get_submodule(module_name) - except Exception as e: - print(type(model), name, module_name, param_name) - raise e + module = model.get_submodule(module_name) current_tensor = parameters[name] slice_ = f.get_slice(name) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index cbfe7ccf..93e31b4a 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -68,24 +68,18 @@ class Seq2SeqLMBatch: # Decoder sequence only contains the bos_token decoder_input_ids.append(tokenizer.bos_token_id) decoder_input_lengths.append(1) - next_token_choosers.append( - NextTokenChooser( - temperature=r.parameters.temperature, - top_k=r.parameters.top_k, - top_p=r.parameters.top_p, - do_sample=r.parameters.do_sample, - ) - ) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters)) stopping_criterias.append( - StoppingCriteria( - eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens - ) + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) # Tokenize batch pad_to_multiple_of = 8 if "gpu" in str(device) else None tokenized_inputs = tokenizer( - inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of + inputs, + return_tensors="pt", + padding=True, + pad_to_multiple_of=pad_to_multiple_of, ).to(device) # Convert decoder_input_ids to torch tensor of size [batch_size, 1] decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1) @@ -431,12 +425,15 @@ class Seq2SeqLM(Model): decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)]) # Evaluate stopping criteria - if stopping_criteria(decoder_tokens): + stop, reason = stopping_criteria(decoder_tokens) + if stop: # Decode tokens output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True) # Add to the list of finished generations with the original request generated_texts.append( - GeneratedText(request, output, stopping_criteria.current_tokens) + GeneratedText( + request, output, stopping_criteria.current_tokens, reason + ) ) # add to the next batch else: diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 7c25bf67..91ec75d0 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -32,8 +32,12 @@ class GeneratedText: request: generate_pb2.Request output: str tokens: int + reason: str def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( - request=self.request, output=self.output, tokens=self.tokens + request=self.request, + output=self.output, + tokens=self.tokens, + finish_reason=self.reason, ) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index e55eeb64..661df05e 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -1,6 +1,5 @@ import concurrent import os -import signal import torch import torch.distributed @@ -11,6 +10,8 @@ from functools import partial from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache from huggingface_hub.utils import LocalEntryNotFoundError from tqdm import tqdm +from typing import List, Optional, Tuple +from transformers import AutoTokenizer from transformers.generation.logits_process import ( LogitsProcessorList, TemperatureLogitsWarper, @@ -18,6 +19,8 @@ from transformers.generation.logits_process import ( TopKLogitsWarper, ) +from text_generation.pb import generate_pb2 + class Sampling: def __call__(self, logits): @@ -56,20 +59,72 @@ class NextTokenChooser: next_ids = self.choice(scores) return next_ids.unsqueeze(-1) + @classmethod + def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser": + return NextTokenChooser( + temperature=pb.temperature, + top_k=pb.top_k, + top_p=pb.top_p, + do_sample=pb.do_sample, + ) + + +class StopSequenceCriteria: + def __init__(self, tokens: List[int]): + if not tokens: + raise ValueError("tokens cannot be empty") + + self.tokens = tokens + self.current_token_idx = 0 + + def __call__(self, last_token: int) -> bool: + if last_token == self.tokens[self.current_token_idx]: + # Increase idx to go to next token + self.current_token_idx += 1 + else: + # Reset to first token of the stopping sequence + self.current_token_idx = 0 + + if self.current_token_idx == len(self.tokens): + # We matched the entire sequence without resetting + return True + return False + class StoppingCriteria: - def __init__(self, eos_token_id, max_new_tokens=20): - self.eos_token_id = eos_token_id + def __init__( + self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20 + ): + self.stop_sequence_criterias = stop_sequence_criterias self.max_new_tokens = max_new_tokens self.current_tokens = 0 - def __call__(self, all_ids): + def __call__(self, all_ids) -> Tuple[bool, Optional[str]]: self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: - return True - if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id: - return True - return False + return True, "length" + + last_token = all_ids[-1] + for stop_sequence_criteria in self.stop_sequence_criterias: + if stop_sequence_criteria(last_token): + return True, "stop_sequence" + + return False, None + + @classmethod + def from_pb( + cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer + ) -> "StoppingCriteria": + stop_sequence_criterias = [] + for stop_sequence in pb.stop_sequences: + tokens = tokenizer( + stop_sequence, padding=False, return_attention_mask=False + ).input_ids + if tokens: + stop_sequence_criterias.append(StopSequenceCriteria(tokens)) + stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id])) + + return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens) def initialize_torch_distributed():