Gemma GPTQ checks: skip logprob checks
This test fails somewhat regularly due to non-determinism and this test is primarily to verify that we are loading a model which doesn't have `float16` as the default dtype correctly.
This commit is contained in:
parent
36dd16017c
commit
967ced2ff4
|
@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
|
||||
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
|
||||
response = await flash_gemma_gptq.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
|
||||
async def test_flash_gemma_gptq_all_params(
|
||||
flash_gemma_gptq, ignore_logprob_response_snapshot
|
||||
):
|
||||
response = await flash_gemma_gptq.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
|
@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
|
|||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_gptq_load(
|
||||
flash_gemma_gptq, generate_load, response_snapshot
|
||||
flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
|
||||
|
@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load(
|
|||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
assert responses == ignore_logprob_response_snapshot
|
||||
|
|
Loading…
Reference in New Issue