fix(server): use server tokenizer as gt (#128)

OlivierDehaene 2023-03-16 12:12:26 +01:00 committed by GitHub
parent 8ad60b752f
commit b49dbf2d88
10 changed files with 46 additions and 54 deletions
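
The net effect of the change: the gRPC Request no longer carries a client-computed input_length. The Python server tokenizes the inputs itself, derives per-request lengths from the resulting attention mask, and the batch metadata field max_sequence_length is renamed to max_input_length to match. A minimal sketch of the server-side derivation, assuming a Hugging Face tokenizer; the model id and prompts below are illustrative, not taken from this commit:

from transformers import AutoTokenizer

# Illustrative model; the server uses the tokenizer of whatever model it serves.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")

batch_inputs = ["Test", "A somewhat longer prompt"]
tokenized_inputs = tokenizer(
    batch_inputs,
    return_tensors="pt",
    padding=True,
    return_token_type_ids=False,
)

# Token counts now come from server-side tokenization, not from the client.
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max().item()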

View File

@@ -58,12 +58,10 @@ message Request {
uint64 id = 1;
/// The generation context
string inputs = 2;
/// The number of tokens inside inputs
uint32 input_length = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
NextTokenChooserParameters parameters = 3;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
StoppingCriteriaParameters stopping_parameters = 4;
}
message Batch {

View File

@@ -278,7 +278,8 @@ async fn batching_task(
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
@@ -305,7 +306,8 @@ async fn batching_task(
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);

View File

@@ -165,7 +165,8 @@ impl State {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
@@ -173,7 +174,6 @@ impl State {
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
input_length: entry.request.input_length,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
@@ -226,7 +226,6 @@ mod tests {
Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
input_length: 0,
parameters: NextTokenChooserParameters {
temperature: 0.0,
top_k: 0,

View File

@@ -322,7 +322,6 @@ fn validate(
Ok(ValidGenerateRequest {
inputs,
input_length: input_length as u32,
parameters,
stopping_parameters,
})
@@ -337,7 +336,6 @@ type ValidationRequest = (
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}

View File

@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
@@ -222,7 +221,7 @@ def test_batch_concatenate(
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
@@ -220,7 +219,7 @@ def test_batch_concatenate(
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_length=5,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
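
Without a client-supplied input_length, the token count for special-token inputs like the FIM prompt above is whatever the server's tokenizer reports. A quick way to see that count, assuming the SantaCoder tokenizer (the model id is an assumption for illustration):

from transformers import AutoTokenizer

# Illustrative model id; the served model's own tokenizer is the ground truth.
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

tokenized = tokenizer(
    "<fim-prefix>def<fim-suffix>world<fim-middle>",
    return_tensors="pt",
)
print(tokenized["attention_mask"].sum(1))  # length as the server will count it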

View File

@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=2,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)

View File

@@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
# Metadata used for padding
size: int
max_sequence_length: int
max_input_length: int
padding_right_offset: int
# Past metadata
@@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
input_lengths = []
# Parse batch
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
return_token_type_ids=False,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
input_lengths=input_lengths.tolist(),
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max_sequence_length,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
@@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding
total_batch_size = 0
max_sequence_length = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
@@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_sequence_length + padding_right_offset),
(total_batch_size, max_input_length + padding_right_offset),
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_sequence_length - batch.max_sequence_length
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_sequence_length
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
@@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
padded_past_values_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
max_input_length - 1,
head_dim,
)
@@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
max_input_length - 1,
)
# This will run only once per layer
@@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
-(batch.max_input_length - 1) :,
:,
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
else:
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-(batch.max_input_length - 1) :,
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += batch.size
@@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
)
@@ -352,7 +352,7 @@ class CausalLM(Model):
# Metadata
next_batch_size = 0
next_batch_max_sequence_length = 0
next_batch_max_input_length = 0
# Results
generations: List[Generation] = []
@@ -420,8 +420,8 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
# Prefill
@@ -506,7 +506,7 @@ class CausalLM(Model):
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last,
)
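
The renamed max_input_length feeds all of the padding arithmetic when batches are concatenated. A small worked example of the two offsets computed above; every number is made up for illustration:

# Merged batch: longest live prompt across all batches, plus room still
# reserved on the right for tokens that will be generated later.
max_input_length = 10
padding_right_offset = 8

# One batch being merged in. Its attention mask is wider than it currently
# needs, because earlier steps and filtered-out requests left unused columns.
batch_mask_width = 20           # batch.attention_mask.shape[1]
batch_max_input_length = 6      # batch.max_input_length
batch_padding_right_offset = 7  # batch.padding_right_offset

# Extra left padding this batch's rows receive inside the merged mask.
left_offset = max_input_length - batch_max_input_length  # 4

# Unused columns at the left of the old mask that are sliced away.
batch_left_offset = (
    batch_mask_width - batch_max_input_length - batch_padding_right_offset
)  # 7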

View File

@@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
input_lengths = []
decoder_input_ids = []
decoder_input_lengths = []
# Parse batch
max_input_length = 0
padding_right_offset = 0
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
@@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_input_length = max(max_input_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
padding=True,
return_token_type_ids=False,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
input_lengths=input_lengths,
input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max(input_lengths),
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
)
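
Seq2SeqLMBatch mirrors the causal LM change: encoder input lengths now come from the tokenizer's attention mask (decoder lengths still start at 1, the bos token), and the tensor results are converted to plain Python values before being stored on the batch. A minimal sketch with a hand-written mask in place of real tokenizer output:

import torch

# Stand-in for tokenized_inputs["attention_mask"]: two encoder inputs of
# 3 and 5 tokens, padded to a common width.
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
])

input_lengths = attention_mask.sum(1)       # tensor([3, 5])
max_input_length = input_lengths.max()      # tensor(5)

# Stored on the batch as plain Python values.
input_lengths = input_lengths.tolist()      # [3, 5]
max_input_length = max_input_length.item()  # 5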