fix(server): use server tokenizer as gt (#128)

OlivierDehaene, 2023-03-16 12:12:26 +01:00 (committed by GitHub)
parent 8ad60b752f
commit b49dbf2d88
10 changed files with 46 additions and 54 deletions

View File

@@ -58,12 +58,10 @@ message Request {
     uint64 id = 1;
     /// The generation context
     string inputs = 2;
-    /// The number of tokens inside inputs
-    uint32 input_length = 3;
     /// Next Token Chooser Parameters
-    NextTokenChooserParameters parameters = 4;
+    NextTokenChooserParameters parameters = 3;
     /// Stopping Criteria Parameters
-    StoppingCriteriaParameters stopping_parameters = 5;
+    StoppingCriteriaParameters stopping_parameters = 4;
 }

 message Batch {
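
Note: with input_length removed from the Request message, clients only send the raw text and the server derives token counts from its own tokenizer. A minimal, hedged sketch of building a request against the updated schema; the import path and the exact parameter fields are assumptions based on the server test fixtures in this commit, not part of the diff itself:

    # Hedged sketch (not part of the diff): import path and parameter fields
    # are assumptions taken from the server test fixtures below.
    from text_generation_server.pb import generate_pb2

    request = generate_pb2.Request(
        id=0,
        inputs="Test",  # raw text only; no input_length field anymore
        parameters=generate_pb2.NextTokenChooserParameters(temperature=1.0, top_k=0),
        stopping_parameters=generate_pb2.StoppingCriteriaParameters(max_new_tokens=10),
    )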

View File

@@ -278,7 +278,8 @@ async fn batching_task(
             // because a new batch is being computed
             let entry_waiting_span =
                 info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
-            // Add relationship
+            // Add relationships
+            span.follows_from(&entry_waiting_span);
             entry_waiting_span.follows_from(&span);
             // Update entry
             entry.temp_span = Some(entry_waiting_span);
@@ -305,7 +306,8 @@ async fn batching_task(
             // Create a new span to link the batch back to this entry
             let entry_batch_span =
                 info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-            // Add relationship
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
             entry_batch_span.follows_from(&next_batch_span);
             // Update entry
             entry.temp_span = Some(entry_batch_span);

View File

@@ -165,7 +165,8 @@ impl State {
             // Create a new span to link the batch back to this entry
             let entry_batch_span =
                 info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-            // Add relationship
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
             entry_batch_span.follows_from(&next_batch_span);
             // Update entry
             entry.temp_span = Some(entry_batch_span);
@@ -173,7 +174,6 @@ impl State {
             batch_requests.push(Request {
                 id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.request.input_length,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
@@ -226,7 +226,6 @@ mod tests {
         Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
-                input_length: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,

View File

@@ -322,7 +322,6 @@ fn validate(
     Ok(ValidGenerateRequest {
         inputs,
-        input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })
@@ -337,7 +336,6 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
-    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }

View File

@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]


 def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.input_ids[0, 0] == 10264

     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]

     assert next_batch.past_key_values is not None
     assert all(
@@ -222,7 +221,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids == 10264)

     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3

     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]


 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert next_batch.input_ids[0, 0] == 13

     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]

     assert next_batch.past_key_values is not None
     assert all(
@@ -220,7 +219,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids[1:] == 13)

     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3

     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
-        input_length=5,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )

View File

@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=2,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )

View File

@@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
-    max_sequence_length: int
+    max_input_length: int
     padding_right_offset: int

     # Past metadata
@@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
         input_lengths = []

         # Parse batch
-        max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
             return_token_type_ids=False,
         ).to(device)

+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
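
Note: this is the heart of the change. Input lengths are no longer trusted from the router; from_pb now takes them from the attention_mask produced by the server's own tokenizer, and max_input_length sizes the pre-allocated attention mask. A small sketch with toy tensors standing in for the tokenizer output; the shapes, padding side, and padding_right_offset value below are illustrative only:

    import torch

    # Two requests, padded by the tokenizer to the longest prompt (3 tokens).
    tokenized_attention_mask = torch.tensor(
        [[1, 1, 1],
         [0, 1, 1]]
    )
    padding_right_offset = 4  # stands in for max(max_new_tokens) over the batch

    input_lengths = tokenized_attention_mask.sum(1)   # tensor([3, 2])
    max_input_length = input_lengths.max()            # tensor(3)

    # Allocate the full mask once (prompt + room for generated tokens), then
    # copy the tokenizer mask into the prompt columns.
    attention_mask = tokenized_attention_mask.new_zeros(
        (2, max_input_length + padding_right_offset)
    )
    attention_mask[:, :max_input_length] = tokenized_attention_mask

    # The batch stores plain Python values downstream:
    input_lengths_list = input_lengths.tolist()       # [3, 2]
    max_input_length_int = max_input_length.item()    # 3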
@@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
-        max_sequence_length = 0
+        max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
-            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
@@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length + padding_right_offset),
+                    (total_batch_size, max_input_length + padding_right_offset),
                 )

             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
-            left_offset = max_sequence_length - batch.max_sequence_length
+            left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
                 batch.attention_mask.shape[1]
-                - batch.max_sequence_length
+                - batch.max_input_length
                 - batch.padding_right_offset
             )
             attention_mask[
@@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
                 padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                     head_dim,
                 )
@@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
                     total_batch_size,
                     num_heads,
                     head_dim,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                 )

                 # This will run only once per layer
@@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1) :,
+                        -(batch.max_input_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                        -(batch.max_input_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]

                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]

             start_index += batch.size
@@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
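
Note: in concatenate, batches that were allocated against different max_input_length values are re-aligned into one padded canvas; left_offset shifts shorter batches to the right, and batch_left_offset trims unused allocated columns from the old mask. A toy sketch of that offset arithmetic; the final copy statement is an assumption consistent with these offsets, since the hunk above cuts off at attention_mask[:

    import torch

    # Merged-batch geometry (toy numbers; both right offsets are nonzero here).
    max_input_length = 5          # max over the batches being merged
    padding_right_offset = 3      # max over the batches being merged

    # One batch being folded in; its mask was allocated wider than what is
    # still live, so one column is unused.
    batch_max_input_length = 3
    batch_padding_right_offset = 2
    batch_attention_mask = torch.ones((1, 6))

    left_offset = max_input_length - batch_max_input_length           # 2
    batch_left_offset = (
        batch_attention_mask.shape[1]
        - batch_max_input_length
        - batch_padding_right_offset                                   # 6 - 3 - 2 = 1
    )

    attention_mask = batch_attention_mask.new_zeros(
        (1, max_input_length + padding_right_offset)                   # width 8
    )
    # Assumed copy: take only the live window of the old mask and drop it into
    # the right-aligned slot of the new one (both slices are 3 columns wide).
    attention_mask[
        0:1, left_offset:-padding_right_offset
    ] = batch_attention_mask[
        :, batch_left_offset:-batch_padding_right_offset
    ]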
@@ -352,7 +352,7 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
-        next_batch_max_sequence_length = 0
+        next_batch_max_input_length = 0

         # Results
         generations: List[Generation] = []
@@ -420,8 +420,8 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, new_input_length
                 )

             # Prefill
@@ -506,7 +506,7 @@ class CausalLM(Model):
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
+            max_input_length=next_batch_max_input_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )

View File

@@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []

-        input_lengths = []
         decoder_input_ids = []
         decoder_input_lengths = []

         # Parse batch
-        max_input_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
@@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_input_length = max(max_input_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
-            max_input_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )
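
Note: Seq2SeqLMBatch.from_pb follows the same pattern on the encoder side (lengths derived from the server tokenizer's attention_mask), while the decoder side still starts every sequence from a single bos token. A brief sketch; bos_token_id and the toy mask are placeholders:

    import torch

    bos_token_id = 0   # placeholder; the real value comes from the tokenizer
    tokenized_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])  # toy encoder mask

    # Encoder lengths: same tokenizer-derived bookkeeping as CausalLMBatch above.
    input_lengths = tokenized_attention_mask.sum(1)
    max_input_length = input_lengths.max().item()

    # Decoder side: one bos token per request, shape [batch_size, 1].
    batch_size = tokenized_attention_mask.shape[0]
    decoder_input_ids = torch.tensor([bos_token_id] * batch_size).unsqueeze(-1)
    decoder_input_lengths = [1] * batch_size
    max_decoder_input_length = 1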