Further fixes. (#2426)

* Further fixes.

* Update the conftest to allow NaN (first logprob).

* Fix the condition.
This commit is contained in:
Nicolas Patry 2024-08-16 13:21:44 +02:00 committed by GitHub
parent 99b662f8c2
commit c7ab1810d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 164 deletions

View File

@ -40,14 +40,14 @@ RUN cargo build --profile release-opt
# Python builder # Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0 ARG PYTORCH_VERSION=2.4.0
ARG PYTHON_VERSION=3.10 ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1 ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=24.3.0-0 ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch ARG INSTALL_CHANNEL=pytorch

View File

@ -118,6 +118,7 @@ class ResponseComparator(JSONSnapshotExtension):
and token.text == other.text and token.text == other.text
and ( and (
self.ignore_logprob self.ignore_logprob
or (token.logprob == other.logprob and token.logprob is None)
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
) )
and token.special == other.special and token.special == other.special

View File

@ -12,12 +12,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -25,61 +25,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -101,12 +101,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -114,61 +114,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -190,12 +190,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -203,61 +203,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -279,12 +279,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -292,61 +292,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }

View File

@ -17,37 +17,37 @@
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -8.859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.21826172, "logprob": -0.22033691,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.3085938, "logprob": -1.2939453,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.2548828, "logprob": -0.25268555,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -4.8007812, "logprob": -4.796875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.7871094, "logprob": -3.796875,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -0.81152344, "logprob": -0.8066406,
"text": " List" "text": " List"
}, },
{ {
@ -57,7 +57,7 @@
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.46313477, "logprob": -0.46166992,
"text": "float" "text": "float"
}, },
{ {
@ -70,7 +70,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.046936035, "logprob": -0.046844482,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
@ -103,22 +103,22 @@
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -8.859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.21899414, "logprob": -0.21826172,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.3105469, "logprob": -1.2871094,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.25561523, "logprob": -0.25390625,
"text": "(" "text": "("
}, },
{ {
@ -131,92 +131,6 @@
"logprob": -3.7890625, "logprob": -3.7890625,
"text": ":" "text": ":"
}, },
{
"id": 1682,
"logprob": -0.80615234,
"text": " List"
},
{
"id": 77,
"logprob": -0.22375488,
"text": "["
},
{
"id": 1808,
"logprob": -0.46801758,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0253906,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -0.046447754,
"special": false,
"text": "\n "
},
{
"id": 0,
"logprob": null,
"special": true,
"text": "<|endoftext|>"
}
],
"top_tokens": null
},
"generated_text": "\n "
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -8.9453125,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.2163086,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2958984,
"text": "mean"
},
{
"id": 26,
"logprob": -0.2529297,
"text": "("
},
{
"id": 62,
"logprob": -4.796875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7910156,
"text": ":"
},
{ {
"id": 1682, "id": 1682,
"logprob": -0.8076172, "logprob": -0.8076172,
@ -224,12 +138,12 @@
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.22375488, "logprob": -0.22302246,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.46655273, "logprob": -0.46435547,
"text": "float" "text": "float"
}, },
{ {
@ -242,7 +156,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.0463562, "logprob": -0.046722412,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
@ -275,47 +189,133 @@
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -8.859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.21862793, "logprob": -0.21813965,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.3095703, "logprob": -1.2744141,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.25512695, "logprob": -0.2512207,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -4.796875, "logprob": -4.8046875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.7890625, "logprob": -3.7851562,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -0.79589844, "logprob": -0.81396484,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.22692871, "logprob": -0.22570801,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.46801758, "logprob": -0.46044922,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0234375,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -0.04650879,
"special": false,
"text": "\n "
},
{
"id": 0,
"logprob": null,
"special": true,
"text": "<|endoftext|>"
}
],
"top_tokens": null
},
"generated_text": "\n "
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -8.9453125,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.8515625,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21960449,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2890625,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25073242,
"text": "("
},
{
"id": 62,
"logprob": -4.8085938,
"text": "L"
},
{
"id": 44,
"logprob": -3.8046875,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8071289,
"text": " List"
},
{
"id": 77,
"logprob": -0.22570801,
"text": "["
},
{
"id": 1808,
"logprob": -0.46118164,
"text": "float" "text": "float"
}, },
{ {
@ -328,7 +328,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.04638672, "logprob": -0.046539307,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },

View File

@ -21,6 +21,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
"Test request", max_new_tokens=10, decoder_input_details=True "Test request", max_new_tokens=10, decoder_input_details=True
) )
assert response.generated_text == " for the 2019-2020 school year"
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == response_snapshot
@ -57,6 +58,8 @@ async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_sna
) )
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert responses[0].generated_text == " for the 2019-2020 school year"
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot assert responses == response_snapshot