fix(server): use server tokenizer as gt (#128)
parent 8ad60b752f
commit b49dbf2d88
@@ -58,12 +58,10 @@ message Request {
     uint64 id = 1;
     /// The generation context
     string inputs = 2;
-    /// The number of tokens inside inputs
-    uint32 input_length = 3;
     /// Next Token Chooser Parameters
-    NextTokenChooserParameters parameters = 4;
+    NextTokenChooserParameters parameters = 3;
     /// Stopping Criteria Parameters
-    StoppingCriteriaParameters stopping_parameters = 5;
+    StoppingCriteriaParameters stopping_parameters = 4;
 }
 
 message Batch {
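
Note on the hunk above: dropping input_length also renumbers the remaining fields, and protobuf identifies fields by number on the wire, so the router and server have to move to the new schema together. A minimal sketch of building the slimmer Request from Python, mirroring the test fixtures further down (the import path and the parameter fields beyond temperature/top_k are assumptions, not taken from this diff):

    from text_generation.pb import generate_pb2  # assumed generated-module path

    request = generate_pb2.Request(
        id=0,
        inputs="Test",
        # No input_length anymore: the server recomputes it from its own tokenizer.
        parameters=generate_pb2.NextTokenChooserParameters(temperature=0.0, top_k=0),
        stopping_parameters=generate_pb2.StoppingCriteriaParameters(),
    )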
@@ -278,7 +278,8 @@ async fn batching_task(
                     // because a new batch is being computed
                     let entry_waiting_span =
                         info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
-                    // Add relationship
+                    // Add relationships
+                    span.follows_from(&entry_waiting_span);
                     entry_waiting_span.follows_from(&span);
                     // Update entry
                     entry.temp_span = Some(entry_waiting_span);
@@ -305,7 +306,8 @@ async fn batching_task(
                 // Create a new span to link the batch back to this entry
                 let entry_batch_span =
                     info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationship
+                // Add relationships
+                next_batch_span.follows_from(&entry_batch_span);
                 entry_batch_span.follows_from(&next_batch_span);
                 // Update entry
                 entry.temp_span = Some(entry_batch_span);
@@ -165,7 +165,8 @@ impl State {
             // Create a new span to link the batch back to this entry
             let entry_batch_span =
                 info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-            // Add relationship
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
             entry_batch_span.follows_from(&next_batch_span);
             // Update entry
             entry.temp_span = Some(entry_batch_span);
@@ -173,7 +174,6 @@ impl State {
             batch_requests.push(Request {
                 id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.request.input_length,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
@@ -226,7 +226,6 @@ mod tests {
         Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
-                input_length: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,
@@ -322,7 +322,6 @@ fn validate(
 
     Ok(ValidGenerateRequest {
         inputs,
-        input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })
@@ -337,7 +336,6 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
-    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
 
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]
 
 
 def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.input_ids[0, 0] == 10264
 
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
     assert all(
@@ -222,7 +221,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids == 10264)
 
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
 
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
 
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]
 
 
 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert next_batch.input_ids[0, 0] == 13
 
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
 
     assert next_batch.past_key_values is not None
     assert all(
@@ -220,7 +219,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids[1:] == 13)
 
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
 
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
-        input_length=5,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=2,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
 
     # Metadata used for padding
     size: int
-    max_sequence_length: int
+    max_input_length: int
     padding_right_offset: int
 
     # Past metadata
@@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
         input_lengths = []
 
         # Parse batch
-        max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
             return_token_type_ids=False,
         ).to(device)
 
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
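
This hunk is the heart of the fix: rather than trusting the router-supplied input_length, the batch now derives every length from the attention mask produced by the server's own tokenizer, which is the ground truth ("gt") the commit title refers to. A minimal sketch with toy tensors (stand-ins, not real tokenizer output) of what the two added lines compute:

    import torch

    # Padded batch of two sequences; a 0 in the attention mask marks padding.
    tokenized_inputs = {
        "input_ids": torch.tensor([[3, 7, 9], [0, 4, 2]]),
        "attention_mask": torch.tensor([[1, 1, 1], [0, 1, 1]]),
    }

    # Per-sequence length = count of non-padding positions, as in the diff above.
    input_lengths = tokenized_inputs["attention_mask"].sum(1)  # tensor([3, 2])
    max_input_length = input_lengths.max()                     # tensor(3)

    # The batch stores plain Python values, hence .tolist() and .item() below.
    assert input_lengths.tolist() == [3, 2]
    assert max_input_length.item() == 3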
@@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
 
@@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
-        max_sequence_length = 0
+        max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
-            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
@@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length + padding_right_offset),
+                    (total_batch_size, max_input_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
-            left_offset = max_sequence_length - batch.max_sequence_length
+            left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
                 batch.attention_mask.shape[1]
-                - batch.max_sequence_length
+                - batch.max_input_length
                 - batch.padding_right_offset
             )
             attention_mask[
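
As a hedged illustration of the offsets in the hunk above (made-up sizes, simplified slicing, not the exact library code): each batch's live tokens sit right-aligned against the shared max_input_length column, with padding_right_offset columns reserved after them, and the two offsets select exactly the live region:

    import torch

    # Hypothetical sizes, for illustration only.
    max_input_length, padding_right_offset = 5, 4              # combined batch
    batch_max_input_length, batch_padding_right_offset = 3, 2  # one incoming batch

    # Incoming mask layout: [live tokens | right allocation] columns.
    batch_mask = torch.ones(2, batch_max_input_length + batch_padding_right_offset)

    left_offset = max_input_length - batch_max_input_length  # 2 unused left columns
    batch_left_offset = (
        batch_mask.shape[1] - batch_max_input_length - batch_padding_right_offset
    )  # 0 here; non-zero once earlier steps left stale padding on the left

    combined_mask = torch.zeros(2, max_input_length + padding_right_offset)
    # Copy only the live-token region, dropping this batch's own padding.
    combined_mask[:, left_offset:max_input_length] = batch_mask[
        :, batch_left_offset : batch_left_offset + batch_max_input_length
    ]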
@@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
                 padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                     head_dim,
                 )
 
@@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
                     total_batch_size,
                     num_heads,
                     head_dim,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                 )
 
                 # This will run only once per layer
@@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
                 past_key_values[j][0][
                     start_index:end_index,
                     :,
-                    -(batch.max_sequence_length - 1) :,
+                    -(batch.max_input_length - 1) :,
                     :,
-                ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
+                ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
             else:
                 past_key_values[j][0][
                     start_index:end_index,
                     :,
                     :,
-                    -(batch.max_sequence_length - 1) :,
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                    -(batch.max_input_length - 1) :,
+                ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
 
             past_key_values[j][1][
-                start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-            ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                start_index:end_index, :, -(batch.max_input_length - 1) :, :
+            ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
 
             start_index += batch.size
 
|
@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
size=total_batch_size,
|
size=total_batch_size,
|
||||||
max_sequence_length=max_sequence_length,
|
max_input_length=max_input_length,
|
||||||
padding_right_offset=padding_right_offset,
|
padding_right_offset=padding_right_offset,
|
||||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
)
|
)
|
||||||
|
@ -352,7 +352,7 @@ class CausalLM(Model):
|
||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
next_batch_size = 0
|
next_batch_size = 0
|
||||||
next_batch_max_sequence_length = 0
|
next_batch_max_input_length = 0
|
||||||
|
|
||||||
# Results
|
# Results
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
|
@ -420,8 +420,8 @@ class CausalLM(Model):
|
||||||
next_batch_all_input_ids.append(all_input_ids)
|
next_batch_all_input_ids.append(all_input_ids)
|
||||||
next_batch_size += 1
|
next_batch_size += 1
|
||||||
next_batch_input_lengths.append(new_input_length)
|
next_batch_input_lengths.append(new_input_length)
|
||||||
next_batch_max_sequence_length = max(
|
next_batch_max_input_length = max(
|
||||||
next_batch_max_sequence_length, new_input_length
|
next_batch_max_input_length, new_input_length
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
|
@ -506,7 +506,7 @@ class CausalLM(Model):
|
||||||
next_token_choosers=next_batch_next_token_choosers,
|
next_token_choosers=next_batch_next_token_choosers,
|
||||||
stopping_criterias=next_batch_stopping_criterias,
|
stopping_criterias=next_batch_stopping_criterias,
|
||||||
size=next_batch_size,
|
size=next_batch_size,
|
||||||
max_sequence_length=next_batch_max_sequence_length,
|
max_input_length=next_batch_max_input_length,
|
||||||
padding_right_offset=batch.padding_right_offset - 1,
|
padding_right_offset=batch.padding_right_offset - 1,
|
||||||
keys_head_dim_last=batch.keys_head_dim_last,
|
keys_head_dim_last=batch.keys_head_dim_last,
|
||||||
)
|
)
|
||||||
|
|
|
@@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
 
         decoder_input_ids = []
         decoder_input_lengths = []
 
         # Parse batch
-        max_input_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
@@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_input_length = max(max_input_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
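
The encoder side gets the same derivation, which is also why the constructor below moves from max(input_lengths) to max_input_length.item(): once input_lengths is a tensor, Python's built-in max returns a 0-dim tensor rather than an int. A small sketch of the distinction (toy tensor, assumption-free otherwise):

    import torch

    input_lengths = torch.tensor([[1, 1, 1], [0, 1, 1]]).sum(1)  # tensor([3, 2])

    print(max(input_lengths))          # tensor(3): builtin max keeps it a tensor
    print(input_lengths.max().item())  # 3: a plain int for batch metadata
    print(input_lengths.tolist())      # [3, 2]: plain ints per request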
@@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
-            max_input_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )