fix tests

This commit is contained in:
OlivierDehaene 2024-06-12 18:54:25 +02:00
parent 05eb4dcb17
commit abe521204e
3 changed files with 12 additions and 12 deletions

View File

@ -199,7 +199,7 @@ def test_causal_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[],
)
@ -312,8 +312,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_bloom,
[
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
],
[],
)
@ -341,7 +341,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
[],
)

View File

@ -200,7 +200,7 @@ def test_causal_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter(
default_causal_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[],
)
@ -312,8 +312,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
],
[],
)
@ -340,7 +340,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
],
[],
)

View File

@ -208,7 +208,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[],
)
@ -346,8 +346,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
],
[],
)
@ -362,7 +362,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
[],
)