diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py
index 7ed339f4..8ac5f5a1 100644
--- a/integration-tests/models/test_flash_gemma_gptq.py
+++ b/integration-tests/models/test_flash_gemma_gptq.py
@@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
+async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
     response = await flash_gemma_gptq.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
     )
 
     assert response.details.generated_tokens == 10
-    assert response == response_snapshot
+    assert response == ignore_logprob_response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
+async def test_flash_gemma_gptq_all_params(
+    flash_gemma_gptq, ignore_logprob_response_snapshot
+):
     response = await flash_gemma_gptq.generate(
         "Test request",
         max_new_tokens=10,
@@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
-    assert response == response_snapshot
+    assert response == ignore_logprob_response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_gptq_load(
-    flash_gemma_gptq, generate_load, response_snapshot
+    flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
 ):
     responses = await generate_load(
         flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
@@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load(
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-    assert responses == response_snapshot
+    assert responses == ignore_logprob_response_snapshot
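
For reference, these tests assume an `ignore_logprob_response_snapshot` fixture available alongside the existing `response_snapshot` fixture in the integration-test `conftest.py`. The sketch below is only an illustration of what such a fixture could look like; the `IgnoreLogProbResponseComparator` class, its `matches` method, and the way it would hook into the snapshot machinery are assumptions, not the repository's actual implementation.

```python
# Hypothetical sketch (not the repository's conftest.py): a comparator that
# drops per-token logprobs before comparing, so quantization-induced numerical
# noise does not fail otherwise-identical generations.
import pytest


class IgnoreLogProbResponseComparator:
    """Assumed helper: compares serialized responses with logprob fields removed."""

    @staticmethod
    def _strip_logprobs(data):
        # Recursively remove any "logprob" keys from nested dicts/lists.
        if isinstance(data, dict):
            return {
                key: IgnoreLogProbResponseComparator._strip_logprobs(value)
                for key, value in data.items()
                if key != "logprob"
            }
        if isinstance(data, list):
            return [IgnoreLogProbResponseComparator._strip_logprobs(item) for item in data]
        return data

    def matches(self, serialized_response, serialized_snapshot):
        return self._strip_logprobs(serialized_response) == self._strip_logprobs(
            serialized_snapshot
        )


@pytest.fixture
def ignore_logprob_response_snapshot():
    # In the real test suite this would plug into the same snapshot plugin as
    # `response_snapshot`; returning the comparator directly is only illustrative.
    return IgnoreLogProbResponseComparator()
```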