From b49dbf2d88c340c3686e4318985d8b64581364b7 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 16 Mar 2023 12:12:26 +0100
Subject: [PATCH] fix(server): use server tokenizer as gt (#128)

---
 proto/generate.proto                   |  6 +--
 router/src/infer.rs                    |  6 ++-
 router/src/queue.rs                    |  5 +-
 router/src/validation.rs               |  2 -
 server/tests/models/test_bloom.py      |  7 ++-
 server/tests/models/test_causal_lm.py  |  7 ++-
 server/tests/models/test_santacoder.py |  2 -
 server/tests/models/test_seq2seq_lm.py |  1 -
 .../models/causal_lm.py                | 52 +++++++++----------
 .../models/seq2seq_lm.py               | 12 ++---
 10 files changed, 46 insertions(+), 54 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index a47e2ec1..5081ce1c 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -58,12 +58,10 @@ message Request {
   uint64 id = 1;
   /// The generation context
   string inputs = 2;
-  /// The number of tokens inside inputs
-  uint32 input_length = 3;
   /// Next Token Chooser Parameters
-  NextTokenChooserParameters parameters = 4;
+  NextTokenChooserParameters parameters = 3;
   /// Stopping Criteria Parameters
-  StoppingCriteriaParameters stopping_parameters = 5;
+  StoppingCriteriaParameters stopping_parameters = 4;
 }
 
 message Batch {
diff --git a/router/src/infer.rs b/router/src/infer.rs
index e11f4fe6..ae151d8a 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -278,7 +278,8 @@ async fn batching_task(
                     // because a new batch is being computed
                     let entry_waiting_span =
                         info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
-                    // Add relationship
+                    // Add relationships
+                    span.follows_from(&entry_waiting_span);
                     entry_waiting_span.follows_from(&span);
                     // Update entry
                     entry.temp_span = Some(entry_waiting_span);
@@ -305,7 +306,8 @@ async fn batching_task(
                 // Create a new span to link the batch back to this entry
                 let entry_batch_span =
                     info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationship
+                // Add relationships
+                next_batch_span.follows_from(&entry_batch_span);
                 entry_batch_span.follows_from(&next_batch_span);
                 // Update entry
                 entry.temp_span = Some(entry_batch_span);
diff --git a/router/src/queue.rs b/router/src/queue.rs
index db3c509e..df2087e1 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -165,7 +165,8 @@ impl State {
             // Create a new span to link the batch back to this entry
             let entry_batch_span =
                 info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-            // Add relationship
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
             entry_batch_span.follows_from(&next_batch_span);
             // Update entry
             entry.temp_span = Some(entry_batch_span);
@@ -173,7 +174,6 @@ impl State {
             batch_requests.push(Request {
                 id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.request.input_length,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
@@ -226,7 +226,6 @@ mod tests {
         Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
-                input_length: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index cb8dd0a2..1c350caa 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -322,7 +322,6 @@ fn validate(
 
     Ok(ValidGenerateRequest {
         inputs,
-        input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })
@@ -337,7 +336,6 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
-    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 90239f95..2b8ef5f8 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
 
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]
 
 
 def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.input_ids[0, 0] == 10264
 
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
     assert all(
@@ -222,7 +221,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids == 10264)
 
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
 
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 869022fa..76617b62 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
 
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]
 
 
 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
     assert all(
@@ -220,7 +219,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids[1:] == 13)
 
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
 
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index d089def3..753ff5fc 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="defworld",
-        input_length=5,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 764f8f83..2d86c44b 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=2,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index c979b7bc..88ea6c75 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
 
     # Metadata used for padding
     size: int
-    max_sequence_length: int
+    max_input_length: int
     padding_right_offset: int
 
     # Past metadata
@@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
         input_lengths = []
 
         # Parse batch
-        max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
             return_token_type_ids=False,
         ).to(device)
 
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
 
@@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
-        max_sequence_length = 0
+        max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
-            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
@@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length + padding_right_offset),
+                    (total_batch_size, max_input_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
-            left_offset = max_sequence_length - batch.max_sequence_length
+            left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
                 batch.attention_mask.shape[1]
-                - batch.max_sequence_length
+                - batch.max_input_length
                 - batch.padding_right_offset
             )
             attention_mask[
@@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
                 padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                     head_dim,
                 )
 
@@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
                     total_batch_size,
                     num_heads,
                     head_dim,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                 )
 
             # This will run only once per layer
@@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1) :,
+                        -(batch.max_input_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                        -(batch.max_input_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
 
             start_index += batch.size
 
@@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
@@ -352,7 +352,7 @@ class CausalLM(Model):
 
         # Metadata
         next_batch_size = 0
-        next_batch_max_sequence_length = 0
+        next_batch_max_input_length = 0
 
         # Results
         generations: List[Generation] = []
@@ -420,8 +420,8 @@ class CausalLM(Model):
             next_batch_all_input_ids.append(all_input_ids)
             next_batch_size += 1
             next_batch_input_lengths.append(new_input_length)
-            next_batch_max_sequence_length = max(
-                next_batch_max_sequence_length, new_input_length
+            next_batch_max_input_length = max(
+                next_batch_max_input_length, new_input_length
             )
 
         # Prefill
@@ -506,7 +506,7 @@ class CausalLM(Model):
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
+            max_input_length=next_batch_max_input_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 0f7f4df9..0fe5c03f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
 
         decoder_input_ids = []
         decoder_input_lengths = []
 
         # Parse batch
-        max_input_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
@@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_input_length = max(max_input_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
@@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
-            max_input_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )
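
The substance of the patch is the pair of hunks in causal_lm.py and seq2seq_lm.py above: the router no longer ships a per-request input_length, and the Python server instead derives the lengths from its own tokenizer's attention mask ("gt" in the subject reads as "ground truth"). The standalone sketch below mirrors that computation outside the server; it is not part of the patch, and the checkpoint name and prompts are arbitrary examples.

# Illustration only: compute per-request input lengths the way the
# patched CausalLMBatch.from_pb does, using the server-side tokenizer's
# attention mask rather than a client-supplied input_length.
from transformers import AutoTokenizer

# Example checkpoint; any padded tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

inputs = ["Test", "A slightly longer prompt"]

tokenized_inputs = tokenizer(
    inputs,
    return_tensors="pt",
    padding=True,
    return_token_type_ids=False,
)

# Non-padding tokens per row; these replace the removed proto field.
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()

print(input_lengths.tolist(), max_input_length.item())

Summing the attention mask per row counts exactly the tokens the model will attend to, which is why the batch can size its padded attention mask from max_input_length and the input_length field could be dropped from generate.proto.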