fix tests
This commit is contained in:
parent
05eb4dcb17
commit
abe521204e
|
@ -199,7 +199,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
|
||||
next_batch, _ = next_batch.filter(
|
||||
default_bloom,
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
|
||||
[],
|
||||
)
|
||||
|
||||
|
@ -312,8 +312,8 @@ def test_batch_concatenate(
|
|||
next_batch, _ = next_batch.filter(
|
||||
default_bloom,
|
||||
[
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
@ -341,7 +341,7 @@ def test_batch_concatenate(
|
|||
|
||||
next_batch, _ = next_batch.filter(
|
||||
default_bloom,
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
|
||||
[],
|
||||
)
|
||||
|
||||
|
|
|
@ -200,7 +200,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
|
||||
next_batch, _ = next_batch.filter(
|
||||
default_causal_lm,
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
|
||||
[],
|
||||
)
|
||||
|
||||
|
@ -312,8 +312,8 @@ def test_batch_concatenate(
|
|||
next_batch, _ = next_batch.filter(
|
||||
default_causal_lm,
|
||||
[
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
@ -340,7 +340,7 @@ def test_batch_concatenate(
|
|||
next_batch, _ = next_batch.filter(
|
||||
default_causal_lm,
|
||||
[
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
|
|
@ -208,7 +208,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
|||
|
||||
next_batch, _ = next_batch.filter(
|
||||
default_seq2seq_lm,
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
|
||||
[],
|
||||
)
|
||||
|
||||
|
@ -346,8 +346,8 @@ def test_batch_concatenate(
|
|||
next_batch, _ = next_batch.filter(
|
||||
default_seq2seq_lm,
|
||||
[
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
|
||||
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
@ -362,7 +362,7 @@ def test_batch_concatenate(
|
|||
|
||||
next_batch, _ = next_batch.filter(
|
||||
default_seq2seq_lm,
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
|
||||
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
|
||||
[],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue