From 967ced2ff4565a5358d45a1372d32fbab113700b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 30 May 2024 07:10:10 +0000 Subject: [PATCH] Gemma GPTQ checks: skip logprob checks This test fails somewhat regularly due to non-determinism and this test is primarily to verify that we are loading a model which doesn't have `float16` as the default dtype correctly. --- integration-tests/models/test_flash_gemma_gptq.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py index 7ed339f4..8ac5f5a1 100644 --- a/integration-tests/models/test_flash_gemma_gptq.py +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot): +async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): response = await flash_gemma_gptq.generate( "Test request", max_new_tokens=10, decoder_input_details=True ) assert response.details.generated_tokens == 10 - assert response == response_snapshot + assert response == ignore_logprob_response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): +async def test_flash_gemma_gptq_all_params( + flash_gemma_gptq, ignore_logprob_response_snapshot +): response = await flash_gemma_gptq.generate( "Test request", max_new_tokens=10, @@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response == response_snapshot + assert response == ignore_logprob_response_snapshot @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq_load( - flash_gemma_gptq, generate_load, response_snapshot + flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot ): responses = await generate_load( flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 @@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load( assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses]) - assert responses == response_snapshot + assert responses == ignore_logprob_response_snapshot