Gemma GPTQ checks: skip logprob checks

This test fails somewhat regularly due to non-determinism and this
test is primarily to verify that we are loading a model which doesn't
have `float16` as the default dtype correctly.
This commit is contained in:
Daniël de Kok 2024-05-30 07:10:10 +00:00 committed by Daniël de Kok
parent 36dd16017c
commit 967ced2ff4
1 changed files with 8 additions and 6 deletions

View File

@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot): async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
response = await flash_gemma_gptq.generate( response = await flash_gemma_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True "Test request", max_new_tokens=10, decoder_input_details=True
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): async def test_flash_gemma_gptq_all_params(
flash_gemma_gptq, ignore_logprob_response_snapshot
):
response = await flash_gemma_gptq.generate( response = await flash_gemma_gptq.generate(
"Test request", "Test request",
max_new_tokens=10, max_new_tokens=10,
@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_load( async def test_flash_gemma_gptq_load(
flash_gemma_gptq, generate_load, response_snapshot flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
): ):
responses = await generate_load( responses = await generate_load(
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load(
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot assert responses == ignore_logprob_response_snapshot