diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 32ee6686..0daa5f41 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -197,7 +197,9 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0].id])
+    next_batch = next_batch.filter(
+        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    )
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -306,7 +308,14 @@ def test_batch_concatenate(
     )
 
     next_batch = next_batch.filter(
-        [next_batch.requests[0].id, next_batch.requests[1].id]
+        [
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[0].id, blocks=[], slots=[]
+            ),
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[1].id, blocks=[], slots=[]
+            ),
+        ]
     )
 
     for _ in range(
@@ -330,7 +339,9 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    )
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 6e6463bc..547da81f 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -198,7 +198,9 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id])
+    next_batch = next_batch.filter(
+        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    )
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -306,7 +308,14 @@ def test_batch_concatenate(
     )
 
     next_batch = next_batch.filter(
-        [next_batch.requests[0].id, next_batch.requests[1].id]
+        [
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[0].id, blocks=[], slots=[]
+            ),
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[1].id, blocks=[], slots=[]
+            ),
+        ]
     )
 
     for _ in range(
@@ -328,7 +337,16 @@ def test_batch_concatenate(
        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[0].id, blocks=[], slots=[]
+            ),
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[1].id, blocks=[], slots=[]
+            ),
+        ]
+    )
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 943c3b08..17b5fa50 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -206,7 +206,9 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0].id])
+    next_batch = next_batch.filter(
+        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -340,7 +342,14 @@ def test_batch_concatenate(
     assert generations[2].generated_text.generated_tokens == 5
 
     next_batch = next_batch.filter(
-        [next_batch.requests[0].id, next_batch.requests[1].id]
+        [
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[0].id, blocks=[], slots=[]
+            ),
+            generate_pb2.UpdatedRequest(
+                id=next_batch.requests[1].id, blocks=[], slots=[]
+            ),
+        ]
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -351,7 +360,9 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None