Further fixes. (#2426)

* Further fixes.

* Update the conftest to allow NaN (first logprob).

* Fix the condition.
This commit is contained in:
Nicolas Patry 2024-08-16 13:21:44 +02:00 committed by GitHub
parent 99b662f8c2
commit c7ab1810d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 164 deletions

View File

@ -40,14 +40,14 @@ RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch

View File

@ -118,6 +118,7 @@ class ResponseComparator(JSONSnapshotExtension):
and token.text == other.text
and (
self.ignore_logprob
or (token.logprob == other.logprob and token.logprob is None)
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
)
and token.special == other.special

View File

@ -12,12 +12,12 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.5625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.375,
"text": " request"
}
],
@ -25,61 +25,61 @@
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"logprob": -2.15625,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"logprob": -2.703125,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.640625,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"logprob": -1.703125,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"logprob": -1.421875,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"logprob": -2.03125,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"logprob": -0.49023438,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -0.041503906,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"logprob": -0.87109375,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"logprob": -0.012939453,
"special": false,
"text": " year"
}
@ -101,12 +101,12 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.5625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.375,
"text": " request"
}
],
@ -114,61 +114,61 @@
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"logprob": -2.15625,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"logprob": -2.703125,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.640625,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"logprob": -1.703125,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"logprob": -1.421875,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"logprob": -2.03125,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"logprob": -0.49023438,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -0.041503906,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"logprob": -0.87109375,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"logprob": -0.012939453,
"special": false,
"text": " year"
}
@ -190,12 +190,12 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.5625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.375,
"text": " request"
}
],
@ -203,61 +203,61 @@
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"logprob": -2.15625,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"logprob": -2.703125,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.640625,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"logprob": -1.703125,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"logprob": -1.421875,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"logprob": -2.03125,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"logprob": -0.49023438,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -0.041503906,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"logprob": -0.87109375,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"logprob": -0.012939453,
"special": false,
"text": " year"
}
@ -279,12 +279,12 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.5625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.375,
"text": " request"
}
],
@ -292,61 +292,61 @@
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"logprob": -2.15625,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"logprob": -2.703125,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.640625,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"logprob": -1.703125,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"logprob": -1.421875,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"logprob": -2.03125,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"logprob": -0.49023438,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -0.041503906,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"logprob": -0.87109375,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"logprob": -0.012939453,
"special": false,
"text": " year"
}

View File

@ -17,37 +17,37 @@
},
{
"id": 21017,
"logprob": -8.859375,
"logprob": -8.8515625,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21826172,
"logprob": -0.22033691,
"text": "_"
},
{
"id": 6009,
"logprob": -1.3085938,
"logprob": -1.2939453,
"text": "mean"
},
{
"id": 26,
"logprob": -0.2548828,
"logprob": -0.25268555,
"text": "("
},
{
"id": 62,
"logprob": -4.8007812,
"logprob": -4.796875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7871094,
"logprob": -3.796875,
"text": ":"
},
{
"id": 1682,
"logprob": -0.81152344,
"logprob": -0.8066406,
"text": " List"
},
{
@ -57,7 +57,7 @@
},
{
"id": 1808,
"logprob": -0.46313477,
"logprob": -0.46166992,
"text": "float"
},
{
@ -70,7 +70,7 @@
"tokens": [
{
"id": 284,
"logprob": -0.046936035,
"logprob": -0.046844482,
"special": false,
"text": "\n "
},
@ -103,22 +103,22 @@
},
{
"id": 21017,
"logprob": -8.859375,
"logprob": -8.8515625,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21899414,
"logprob": -0.21826172,
"text": "_"
},
{
"id": 6009,
"logprob": -1.3105469,
"logprob": -1.2871094,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25561523,
"logprob": -0.25390625,
"text": "("
},
{
@ -131,92 +131,6 @@
"logprob": -3.7890625,
"text": ":"
},
{
"id": 1682,
"logprob": -0.80615234,
"text": " List"
},
{
"id": 77,
"logprob": -0.22375488,
"text": "["
},
{
"id": 1808,
"logprob": -0.46801758,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0253906,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -0.046447754,
"special": false,
"text": "\n "
},
{
"id": 0,
"logprob": null,
"special": true,
"text": "<|endoftext|>"
}
],
"top_tokens": null
},
"generated_text": "\n "
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -8.9453125,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.2163086,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2958984,
"text": "mean"
},
{
"id": 26,
"logprob": -0.2529297,
"text": "("
},
{
"id": 62,
"logprob": -4.796875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7910156,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8076172,
@ -224,12 +138,12 @@
},
{
"id": 77,
"logprob": -0.22375488,
"logprob": -0.22302246,
"text": "["
},
{
"id": 1808,
"logprob": -0.46655273,
"logprob": -0.46435547,
"text": "float"
},
{
@ -242,7 +156,7 @@
"tokens": [
{
"id": 284,
"logprob": -0.0463562,
"logprob": -0.046722412,
"special": false,
"text": "\n "
},
@ -275,47 +189,133 @@
},
{
"id": 21017,
"logprob": -8.859375,
"logprob": -8.8515625,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21862793,
"logprob": -0.21813965,
"text": "_"
},
{
"id": 6009,
"logprob": -1.3095703,
"logprob": -1.2744141,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25512695,
"logprob": -0.2512207,
"text": "("
},
{
"id": 62,
"logprob": -4.796875,
"logprob": -4.8046875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7890625,
"logprob": -3.7851562,
"text": ":"
},
{
"id": 1682,
"logprob": -0.79589844,
"logprob": -0.81396484,
"text": " List"
},
{
"id": 77,
"logprob": -0.22692871,
"logprob": -0.22570801,
"text": "["
},
{
"id": 1808,
"logprob": -0.46801758,
"logprob": -0.46044922,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0234375,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -0.04650879,
"special": false,
"text": "\n "
},
{
"id": 0,
"logprob": null,
"special": true,
"text": "<|endoftext|>"
}
],
"top_tokens": null
},
"generated_text": "\n "
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -8.9453125,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.8515625,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21960449,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2890625,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25073242,
"text": "("
},
{
"id": 62,
"logprob": -4.8085938,
"text": "L"
},
{
"id": 44,
"logprob": -3.8046875,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8071289,
"text": " List"
},
{
"id": 77,
"logprob": -0.22570801,
"text": "["
},
{
"id": 1808,
"logprob": -0.46118164,
"text": "float"
},
{
@ -328,7 +328,7 @@
"tokens": [
{
"id": 284,
"logprob": -0.04638672,
"logprob": -0.046539307,
"special": false,
"text": "\n "
},

View File

@ -21,6 +21,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.generated_text == " for the 2019-2020 school year"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -57,6 +58,8 @@ async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_sna
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses[0].generated_text == " for the 2019-2020 school year"
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot