Softcapping for gemma2. (#2273)

* Softcapping for gemma2.

* Less clutter.

* No access to transformers config, only config_dict here.

* 0.0 is the null value in the C++ API.
This commit is contained in:
Nicolas Patry 2024-07-22 18:27:10 +02:00 committed by GitHub
parent 4844ff790a
commit 6aeb669072
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 2012 additions and 615 deletions

View File

@ -0,0 +1,254 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 106,
"logprob": -47.25,
"text": "<start_of_turn>"
},
{
"id": 1645,
"logprob": -18.875,
"text": "user"
},
{
"id": 235292,
"logprob": -7.15625,
"text": ":"
},
{
"id": 108,
"logprob": -4.78125,
"text": "\n"
},
{
"id": 5559,
"logprob": -10.0,
"text": "Write"
},
{
"id": 476,
"logprob": -0.1171875,
"text": " a"
},
{
"id": 19592,
"logprob": -2.46875,
"text": " poem"
},
{
"id": 577,
"logprob": -5.84375,
"text": " to"
},
{
"id": 1707,
"logprob": -6.375,
"text": " help"
},
{
"id": 682,
"logprob": -2.125,
"text": " me"
},
{
"id": 5434,
"logprob": -1.546875,
"text": " remember"
},
{
"id": 573,
"logprob": -0.62890625,
"text": " the"
},
{
"id": 1370,
"logprob": -6.65625,
"text": " first"
},
{
"id": 235248,
"logprob": -1.84375,
"text": " "
},
{
"id": 235274,
"logprob": -0.45117188,
"text": "1"
},
{
"id": 235276,
"logprob": -0.07421875,
"text": "0"
},
{
"id": 6635,
"logprob": -2.109375,
"text": " elements"
},
{
"id": 611,
"logprob": -0.4140625,
"text": " on"
},
{
"id": 573,
"logprob": -0.0009536743,
"text": " the"
},
{
"id": 26163,
"logprob": -0.033203125,
"text": " periodic"
},
{
"id": 3037,
"logprob": -0.0002670288,
"text": " table"
},
{
"id": 235269,
"logprob": -4.75,
"text": ","
},
{
"id": 7385,
"logprob": -11.625,
"text": " giving"
},
{
"id": 1853,
"logprob": -4.875,
"text": " each"
},
{
"id": 5356,
"logprob": -0.38867188,
"text": " element"
},
{
"id": 1277,
"logprob": -3.65625,
"text": " its"
},
{
"id": 1997,
"logprob": -4.4375,
"text": " own"
},
{
"id": 2017,
"logprob": -0.29882812,
"text": " line"
},
{
"id": 235265,
"logprob": -0.16699219,
"text": "."
},
{
"id": 107,
"logprob": -25.625,
"text": "<end_of_turn>"
},
{
"id": 108,
"logprob": -6.75,
"text": "\n"
},
{
"id": 106,
"logprob": -39.5,
"text": "<start_of_turn>"
},
{
"id": 2516,
"logprob": -32.5,
"text": "model"
},
{
"id": 235292,
"logprob": -10.125,
"text": ":"
},
{
"id": 108,
"logprob": -3.421875,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 688,
"logprob": -0.546875,
"special": false,
"text": "**"
},
{
"id": 103889,
"logprob": -0.49023438,
"special": false,
"text": "Hydrogen"
},
{
"id": 190213,
"logprob": -0.48632812,
"special": false,
"text": "**,"
},
{
"id": 2611,
"logprob": -0.58203125,
"special": false,
"text": " light"
},
{
"id": 578,
"logprob": -0.099121094,
"special": false,
"text": " and"
},
{
"id": 2223,
"logprob": -1.078125,
"special": false,
"text": " free"
},
{
"id": 235269,
"logprob": -0.025756836,
"special": false,
"text": ","
},
{
"id": 108,
"logprob": -0.29101562,
"special": false,
"text": "\n"
},
{
"id": 688,
"logprob": -0.0035858154,
"special": false,
"text": "**"
},
{
"id": 1949,
"logprob": -4.1007996e-05,
"special": false,
"text": "He"
}
],
"top_tokens": null
},
"generated_text": "**Hydrogen**, light and free,\n**He"
}

View File

@ -0,0 +1,46 @@
import pytest
@pytest.fixture(scope="module")
def flash_gemma2_handle(launcher):
    """Module-scoped fixture yielding the handle for `google/gemma-2-9b-it`.

    Enters the `launcher` context manager with `num_shard=2` and yields the
    resulting handle; the context is exited when the fixture is torn down.
    """
    with launcher("google/gemma-2-9b-it", num_shard=2) as server_handle:
        yield server_handle
@pytest.fixture(scope="module")
async def flash_gemma2(flash_gemma2_handle):
    """Module-scoped fixture returning the client attached to the Gemma2 handle.

    Awaits the handle's `health` check first (called with 300 -- presumably a
    timeout in seconds; confirm against the handle implementation).
    """
    handle = flash_gemma2_handle
    await handle.health(300)
    return handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2(flash_gemma2, response_snapshot):
    """Generate 10 new tokens for a fixed chat-style prompt and verify the result.

    Decoder input details are requested. The test checks the generated text,
    the number of generated tokens, and equality with the stored snapshot.
    """
    prompt = "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n"
    expected_text = "**Hydrogen**, light and free,\n**He"

    response = await flash_gemma2.generate(
        prompt,
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.generated_text == expected_text
    assert response.details.generated_tokens == 10
    assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot):
    """Issue a fixed prompt through `generate_load` with n=4 and verify the results.

    Checks the first response's text, that four responses came back, that all
    responses share the same generated text, and equality with the snapshot.
    """
    prompt = "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n"
    expected_text = "**Hydrogen**, light and free,\n**He"

    responses = await generate_load(
        flash_gemma2,
        prompt,
        max_new_tokens=10,
        n=4,
    )

    assert responses[0].generated_text == expected_text
    assert len(responses) == 4
    # Every response must produce the same text as the first one.
    reference_text = responses[0].generated_text
    assert all(r.generated_text == reference_text for r in responses)
    assert responses == response_snapshot

View File

@ -25,6 +25,7 @@ mod env_runtime;
/// Raw view of the JSON config file parsed in `main` via `serde_json::from_str`.
///
/// Every field is optional, so a config that omits any of them still parses.
/// The value is later converted into `Config` with `.into()`.
struct RawConfig {
/// Optional `max_position_embeddings` entry of the config.
max_position_embeddings: Option<usize>,
/// Optional `n_positions` entry of the config.
n_positions: Option<usize>,
/// Optional model type; `main` compares it against `"gemma2"` and, on a
/// match, sets the `FLASH_DECODING` environment variable to `"1"`.
model_type: Option<String>,
/// Optional `max_seq_len` entry of the config.
max_seq_len: Option<usize>,
}
@ -1418,6 +1419,11 @@ fn main() -> Result<(), LauncherError> {
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("FLASH_DECODING", "1");
}
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained.

View File

@ -1,4 +1,4 @@
flash_att_v2_commit_cuda := v2.5.9.post1
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
build-flash-attention-v2-cuda:

1135
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
huggingface-hub = "^0.23"
transformers = "^4.41"
transformers = "^4.42"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }

View File

@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -2,6 +2,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@ -43,6 +44,7 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -82,6 +84,8 @@ def paged_attention(
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None:
softcap = 0.0
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
@ -89,6 +93,7 @@ def paged_attention(
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None, # pad_k
None,
block_tables,
None,
@ -100,11 +105,14 @@ def paged_attention(
True, # causal
-1, # Window_left
-1, # Window right
softcap,
False, # return softmax
None, # generator
)
return out2[0]
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
from vllm._C import ops
@ -205,6 +213,7 @@ if V2:
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@ -218,6 +227,7 @@ if V2:
None,
None,
None,
None,
max_s,
max_s,
0.0,
@ -226,6 +236,7 @@ if V2:
causal,
window_size_left,
0,
softcap,
False,
None,
)
@ -241,11 +252,14 @@ else:
max_s,
softmax_scale,
window_size_left=-1,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:

View File

@ -762,6 +762,8 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
head_size=config_dict["head_dim"],
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))

View File

@ -189,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.softcap = config.attn_logit_softcapping
self.query_key_value = load_attention(config, prefix, weights)
@ -246,6 +247,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
softcap=self.softcap,
)
# Decode
else:
@ -259,6 +261,7 @@ class FlashGemma2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
softcap=self.softcap,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -466,6 +469,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
config=config,
weights=weights,
)
self.softcap = config.final_logit_softcapping
assert isinstance(self.softcap, float)
def forward(
self,
@ -495,4 +500,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
logits /= self.softcap
logits = torch.tanh(logits)
logits *= self.softcap
return logits, speculative_logits