fix python tests
This commit is contained in:
parent
51fa606875
commit
3c596983ba
|
@ -197,7 +197,9 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
# Copy stopping_criterias before filtering
|
||||
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id])
|
||||
next_batch = next_batch.filter(
|
||||
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
|
||||
|
@ -306,7 +308,14 @@ def test_batch_concatenate(
|
|||
)
|
||||
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
[
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[0].id, blocks=[], slots=[]
|
||||
),
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[1].id, blocks=[], slots=[]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
|
@ -330,7 +339,9 @@ def test_batch_concatenate(
|
|||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
|
|
|
@ -198,7 +198,9 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id])
|
||||
next_batch = next_batch.filter(
|
||||
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
|
||||
|
@ -306,7 +308,14 @@ def test_batch_concatenate(
|
|||
)
|
||||
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
[
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[0].id, blocks=[], slots=[]
|
||||
),
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[1].id, blocks=[], slots=[]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
|
@ -328,7 +337,16 @@ def test_batch_concatenate(
|
|||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[0].id, blocks=[], slots=[]
|
||||
),
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[1].id, blocks=[], slots=[]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
|
|
|
@ -206,7 +206,9 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
|||
)
|
||||
assert generations[1].generated_text.generated_tokens == 5
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id])
|
||||
next_batch = next_batch.filter(
|
||||
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
|
||||
)
|
||||
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
@ -340,7 +342,14 @@ def test_batch_concatenate(
|
|||
assert generations[2].generated_text.generated_tokens == 5
|
||||
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
[
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[0].id, blocks=[], slots=[]
|
||||
),
|
||||
generate_pb2.UpdatedRequest(
|
||||
id=next_batch.requests[1].id, blocks=[], slots=[]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
|
@ -351,7 +360,9 @@ def test_batch_concatenate(
|
|||
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
|
||||
assert generations[0].generated_text.generated_tokens == 7
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
|
||||
)
|
||||
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
|
Loading…
Reference in New Issue