fix(server): use server tokenizer as gt (#128)
This commit is contained in:
parent
8ad60b752f
commit
b49dbf2d88
|
@ -58,12 +58,10 @@ message Request {
|
|||
uint64 id = 1;
|
||||
/// The generation context
|
||||
string inputs = 2;
|
||||
/// The number of tokens inside inputs
|
||||
uint32 input_length = 3;
|
||||
/// Next Token Chooser Parameters
|
||||
NextTokenChooserParameters parameters = 4;
|
||||
NextTokenChooserParameters parameters = 3;
|
||||
/// Stopping Criteria Parameters
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
StoppingCriteriaParameters stopping_parameters = 4;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
|
|
@ -278,7 +278,8 @@ async fn batching_task(
|
|||
// because a new batch is being computed
|
||||
let entry_waiting_span =
|
||||
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
|
||||
// Add relationship
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
|
@ -305,7 +306,8 @@ async fn batching_task(
|
|||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span =
|
||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||
// Add relationship
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
|
|
@ -165,7 +165,8 @@ impl State {
|
|||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span =
|
||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||
// Add relationship
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
@ -173,7 +174,6 @@ impl State {
|
|||
batch_requests.push(Request {
|
||||
id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
input_length: entry.request.input_length,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
});
|
||||
|
@ -226,7 +226,6 @@ mod tests {
|
|||
Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: "".to_string(),
|
||||
input_length: 0,
|
||||
parameters: NextTokenChooserParameters {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
|
|
|
@ -322,7 +322,6 @@ fn validate(
|
|||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
input_length: input_length as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
})
|
||||
|
@ -337,7 +336,6 @@ type ValidationRequest = (
|
|||
#[derive(Debug)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: String,
|
||||
pub input_length: u32,
|
||||
pub parameters: NextTokenChooserParameters,
|
||||
pub stopping_parameters: StoppingCriteriaParameters,
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
|
|||
assert batch.size == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
|
||||
|
||||
assert batch.max_sequence_length == batch.input_lengths[0]
|
||||
assert batch.max_input_length == batch.input_lengths[0]
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_bloom_batch):
|
||||
|
@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
|||
assert next_batch.input_ids[0, 0] == 10264
|
||||
|
||||
assert next_batch.input_lengths == [2]
|
||||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
assert next_batch.max_input_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all(
|
||||
|
@ -222,7 +221,7 @@ def test_batch_concatenate(
|
|||
assert torch.all(next_batch.input_ids == 10264)
|
||||
|
||||
assert next_batch.input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_sequence_length == 3
|
||||
assert next_batch.max_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
|
|
|
@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
|||
assert batch.size == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
|
||||
|
||||
assert batch.max_sequence_length == batch.input_lengths[0]
|
||||
assert batch.max_input_length == batch.input_lengths[0]
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
|
||||
|
@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
|||
assert next_batch.input_ids[0, 0] == 13
|
||||
|
||||
assert next_batch.input_lengths == [2]
|
||||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
assert next_batch.max_input_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all(
|
||||
|
@ -220,7 +219,7 @@ def test_batch_concatenate(
|
|||
assert torch.all(next_batch.input_ids[1:] == 13)
|
||||
|
||||
assert next_batch.input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_sequence_length == 3
|
||||
assert next_batch.max_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
|
|
|
@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="def",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
||||
input_length=5,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
|
|
@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
|||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=2,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
|
|
@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
|
|||
|
||||
# Metadata used for padding
|
||||
size: int
|
||||
max_sequence_length: int
|
||||
max_input_length: int
|
||||
padding_right_offset: int
|
||||
|
||||
# Past metadata
|
||||
|
@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
|
|||
input_lengths = []
|
||||
|
||||
# Parse batch
|
||||
max_sequence_length = 0
|
||||
padding_right_offset = 0
|
||||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_sequence_length = max(max_sequence_length, r.input_length)
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
|
|||
return_token_type_ids=False,
|
||||
).to(device)
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
(pb.size, max_sequence_length + padding_right_offset)
|
||||
(pb.size, max_input_length + padding_right_offset)
|
||||
)
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
|
@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
|
|||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
all_input_ids=all_input_ids,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=input_lengths.tolist(),
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=pb.size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
)
|
||||
|
||||
|
@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
|
|||
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
|
||||
# Used for padding
|
||||
total_batch_size = 0
|
||||
max_sequence_length = 0
|
||||
max_input_length = 0
|
||||
padding_right_offset = 0
|
||||
for batch in batches:
|
||||
total_batch_size += batch.size
|
||||
max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
|
||||
max_input_length = max(max_input_length, batch.max_input_length)
|
||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||
|
||||
# Batch attributes
|
||||
|
@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
|
|||
# Create padded tensor
|
||||
if attention_mask is None:
|
||||
attention_mask = batch.attention_mask.new_zeros(
|
||||
(total_batch_size, max_sequence_length + padding_right_offset),
|
||||
(total_batch_size, max_input_length + padding_right_offset),
|
||||
)
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
# and to remove unused allocated space
|
||||
left_offset = max_sequence_length - batch.max_sequence_length
|
||||
left_offset = max_input_length - batch.max_input_length
|
||||
batch_left_offset = (
|
||||
batch.attention_mask.shape[1]
|
||||
- batch.max_sequence_length
|
||||
- batch.max_input_length
|
||||
- batch.padding_right_offset
|
||||
)
|
||||
attention_mask[
|
||||
|
@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
|
|||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
max_input_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
|
@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
|
|||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
max_input_length - 1,
|
||||
)
|
||||
|
||||
# This will run only once per layer
|
||||
|
@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
|
|||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
-(batch.max_input_length - 1) :,
|
||||
:,
|
||||
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
|
||||
else:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
-(batch.max_input_length - 1) :,
|
||||
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
|
||||
|
||||
past_key_values[j][1][
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
start_index:end_index, :, -(batch.max_input_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
|
@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
|
|||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=total_batch_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
max_input_length=max_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
)
|
||||
|
@ -352,7 +352,7 @@ class CausalLM(Model):
|
|||
|
||||
# Metadata
|
||||
next_batch_size = 0
|
||||
next_batch_max_sequence_length = 0
|
||||
next_batch_max_input_length = 0
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
|
@ -420,8 +420,8 @@ class CausalLM(Model):
|
|||
next_batch_all_input_ids.append(all_input_ids)
|
||||
next_batch_size += 1
|
||||
next_batch_input_lengths.append(new_input_length)
|
||||
next_batch_max_sequence_length = max(
|
||||
next_batch_max_sequence_length, new_input_length
|
||||
next_batch_max_input_length = max(
|
||||
next_batch_max_input_length, new_input_length
|
||||
)
|
||||
|
||||
# Prefill
|
||||
|
@ -506,7 +506,7 @@ class CausalLM(Model):
|
|||
next_token_choosers=next_batch_next_token_choosers,
|
||||
stopping_criterias=next_batch_stopping_criterias,
|
||||
size=next_batch_size,
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
max_input_length=next_batch_max_input_length,
|
||||
padding_right_offset=batch.padding_right_offset - 1,
|
||||
keys_head_dim_last=batch.keys_head_dim_last,
|
||||
)
|
||||
|
|
|
@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
|
|||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
input_lengths = []
|
||||
|
||||
decoder_input_ids = []
|
||||
decoder_input_lengths = []
|
||||
|
||||
# Parse batch
|
||||
max_input_length = 0
|
||||
padding_right_offset = 0
|
||||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
# Decoder sequence only contains the bos_token
|
||||
decoder_input_ids.append(tokenizer.bos_token_id)
|
||||
decoder_input_lengths.append(1)
|
||||
|
@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
|
|||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_input_length = max(max_input_length, r.input_length)
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
|
|||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
).to(device)
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
|
||||
|
||||
|
@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
|
|||
decoder_attention_mask=None,
|
||||
encoder_last_hidden_state=None,
|
||||
past_key_values=None,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=input_lengths.tolist(),
|
||||
decoder_input_lengths=decoder_input_lengths,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=len(pb.requests),
|
||||
max_input_length=max(input_lengths),
|
||||
max_input_length=max_input_length.item(),
|
||||
max_decoder_input_length=1,
|
||||
padding_right_offset=padding_right_offset,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue