Softcapping for gemma2. (#2273)

* Softcapping for gemma2.
* Less clutter.
* No access to transformers config, only config_dict here.
* 0.0 is the null value in the C++ API.
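For context, "softcapping" here is the tanh squashing Gemma 2 applies to its attention scores and final logits. A minimal PyTorch sketch of the operation (illustrative only, not the repository's code; the 30.0 cap below is just an example value), with 0.0 treated as "disabled" to match the C++ API convention noted above:

    import torch

    def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
        # cap == 0.0 is the "no softcapping" sentinel, mirroring the
        # flash-attention C++ API where 0.0 is the null value.
        if cap == 0.0:
            return x
        # Smoothly squashes values into the open interval (-cap, cap).
        return cap * torch.tanh(x / cap)

    logits = torch.randn(2, 8) * 100.0
    print(softcap(logits, 30.0))  # every value now lies strictly between -30 and 30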
This commit is contained in:
parent 4844ff790a
commit 6aeb669072
@@ -0,0 +1,254 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2, "logprob": null, "text": "<bos>" },
      { "id": 106, "logprob": -47.25, "text": "<start_of_turn>" },
      { "id": 1645, "logprob": -18.875, "text": "user" },
      { "id": 235292, "logprob": -7.15625, "text": ":" },
      { "id": 108, "logprob": -4.78125, "text": "\n" },
      { "id": 5559, "logprob": -10.0, "text": "Write" },
      { "id": 476, "logprob": -0.1171875, "text": " a" },
      { "id": 19592, "logprob": -2.46875, "text": " poem" },
      { "id": 577, "logprob": -5.84375, "text": " to" },
      { "id": 1707, "logprob": -6.375, "text": " help" },
      { "id": 682, "logprob": -2.125, "text": " me" },
      { "id": 5434, "logprob": -1.546875, "text": " remember" },
      { "id": 573, "logprob": -0.62890625, "text": " the" },
      { "id": 1370, "logprob": -6.65625, "text": " first" },
      { "id": 235248, "logprob": -1.84375, "text": " " },
      { "id": 235274, "logprob": -0.45117188, "text": "1" },
      { "id": 235276, "logprob": -0.07421875, "text": "0" },
      { "id": 6635, "logprob": -2.109375, "text": " elements" },
      { "id": 611, "logprob": -0.4140625, "text": " on" },
      { "id": 573, "logprob": -0.0009536743, "text": " the" },
      { "id": 26163, "logprob": -0.033203125, "text": " periodic" },
      { "id": 3037, "logprob": -0.0002670288, "text": " table" },
      { "id": 235269, "logprob": -4.75, "text": "," },
      { "id": 7385, "logprob": -11.625, "text": " giving" },
      { "id": 1853, "logprob": -4.875, "text": " each" },
      { "id": 5356, "logprob": -0.38867188, "text": " element" },
      { "id": 1277, "logprob": -3.65625, "text": " its" },
      { "id": 1997, "logprob": -4.4375, "text": " own" },
      { "id": 2017, "logprob": -0.29882812, "text": " line" },
      { "id": 235265, "logprob": -0.16699219, "text": "." },
      { "id": 107, "logprob": -25.625, "text": "<end_of_turn>" },
      { "id": 108, "logprob": -6.75, "text": "\n" },
      { "id": 106, "logprob": -39.5, "text": "<start_of_turn>" },
      { "id": 2516, "logprob": -32.5, "text": "model" },
      { "id": 235292, "logprob": -10.125, "text": ":" },
      { "id": 108, "logprob": -3.421875, "text": "\n" }
    ],
    "seed": null,
    "tokens": [
      { "id": 688, "logprob": -0.546875, "special": false, "text": "**" },
      { "id": 103889, "logprob": -0.49023438, "special": false, "text": "Hydrogen" },
      { "id": 190213, "logprob": -0.48632812, "special": false, "text": "**," },
      { "id": 2611, "logprob": -0.58203125, "special": false, "text": " light" },
      { "id": 578, "logprob": -0.099121094, "special": false, "text": " and" },
      { "id": 2223, "logprob": -1.078125, "special": false, "text": " free" },
      { "id": 235269, "logprob": -0.025756836, "special": false, "text": "," },
      { "id": 108, "logprob": -0.29101562, "special": false, "text": "\n" },
      { "id": 688, "logprob": -0.0035858154, "special": false, "text": "**" },
      { "id": 1949, "logprob": -4.1007996e-05, "special": false, "text": "He" }
    ],
    "top_tokens": null
  },
  "generated_text": "**Hydrogen**, light and free,\n**He"
}
File diff suppressed because it is too large
@@ -0,0 +1,46 @@
import pytest


@pytest.fixture(scope="module")
def flash_gemma2_handle(launcher):
    with launcher("google/gemma-2-9b-it", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma2(flash_gemma2_handle):
    await flash_gemma2_handle.health(300)
    return flash_gemma2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2(flash_gemma2, response_snapshot):
    response = await flash_gemma2.generate(
        "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.generated_text == "**Hydrogen**, light and free,\n**He"
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_gemma2,
        "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
        max_new_tokens=10,
        n=4,
    )

    assert responses[0].generated_text == "**Hydrogen**, light and free,\n**He"
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
@@ -25,6 +25,7 @@ mod env_runtime;
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
}

@@ -1418,6 +1419,11 @@ fn main() -> Result<(), LauncherError> {
    let content = std::fs::read_to_string(filename)?;
    let config: RawConfig = serde_json::from_str(&content)?;

    if config.model_type == Some("gemma2".to_string()) {
        tracing::info!("Forcing flash decoding because of softcap usage");
        std::env::set_var("FLASH_DECODING", "1");
    }
    let config: Config = config.into();

    // Quantization usually means you're even more RAM constrained.
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := v2.5.9.post1
+flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6

build-flash-attention-v2-cuda:
File diff suppressed because it is too large
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
huggingface-hub = "^0.23"
-transformers = "^4.41"
+transformers = "^4.42"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
@@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -1,48 +1,50 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -2,6 +2,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5

@@ -43,6 +44,7 @@ def paged_attention(
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights

@@ -82,6 +84,8 @@ def paged_attention(
        # by the current path
        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
        if softcap is None:
            softcap = 0.0
        out2 = flash_attn_2_cuda.varlen_fwd(
            query,
            key_cache,

@@ -89,6 +93,7 @@ def paged_attention(
            None,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,  # pad_k
            None,
            block_tables,
            None,

@@ -100,11 +105,14 @@ def paged_attention(
            True,  # causal
            -1,  # Window_left
            -1,  # Window right
            softcap,
            False,  # return softmax
            None,  # generator
        )
        return out2[0]
    else:
        if softcap is not None:
            raise RuntimeError("Paged attention doesn't support softcapping")
        input_lengths = seqlen.input_lengths
        from vllm._C import ops

@@ -205,6 +213,7 @@ if V2:
        softmax_scale,
        window_size_left=-1,
        causal=True,
        softcap=0.0,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

@@ -218,6 +227,7 @@ if V2:
            None,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,

@@ -226,6 +236,7 @@ if V2:
            causal,
            window_size_left,
            0,
            softcap,
            False,
            None,
        )

@@ -241,11 +252,14 @@ else:
        max_s,
        softmax_scale,
        window_size_left=-1,
        softcap=None,
    ):
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        if softcap is not None:
            raise NotImplementedError("softcap is only available with flash attn v2")

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
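The softcap argument threaded through above caps the raw attention scores before the softmax; the fused kernel does this internally. An unfused reference sketch of the math (assumed shapes and a generic 1/sqrt(head_dim) scale rather than the model's configured softmax_scale, illustrative only, not the kernel's implementation):

    import math
    import torch

    def softcapped_scores(q: torch.Tensor, k: torch.Tensor, softcap: float) -> torch.Tensor:
        # q, k: [num_heads, seq_len, head_dim] (shapes assumed for illustration).
        scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
        if softcap > 0.0:
            # Same tanh squashing as the final logits, applied to the attention
            # logits before softmax; 0.0 means "disabled".
            scores = softcap * torch.tanh(scores / softcap)
        return torch.softmax(scores, dim=-1)

    q = torch.randn(8, 16, 64)
    k = torch.randn(8, 16, 64)
    probs = softcapped_scores(q, k, softcap=50.0)  # 50.0 is just an example cap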
@@ -762,6 +762,8 @@ def get_model(
            default_dtype=torch.bfloat16,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
            # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
            head_size=config_dict["head_dim"],
        )
    elif sharded:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
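The head_size override above is needed because hidden_size // num_attention_heads does not match the checkpoint's explicit head_dim. A hedged illustration (the numbers below are assumed example values, not read from the actual config):

    # Example values only (assumed); the point is that the derived head size
    # differs from the explicit head_dim, so the explicit value must win.
    config_dict = {"hidden_size": 3584, "num_attention_heads": 16, "head_dim": 256}
    derived = config_dict["hidden_size"] // config_dict["num_attention_heads"]  # 224
    head_size = config_dict.get("head_dim", derived)  # 256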
@@ -189,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module):
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        self.softcap = config.attn_logit_softcapping

        self.query_key_value = load_attention(config, prefix, weights)

@@ -246,6 +247,7 @@ class FlashGemma2Attention(torch.nn.Module):
                self.softmax_scale,
                causal=self.causal,
                window_size_left=self.window_size,
                softcap=self.softcap,
            )
        # Decode
        else:

@@ -259,6 +261,7 @@ class FlashGemma2Attention(torch.nn.Module):
                block_tables,
                input_lengths,
                max_s,
                softcap=self.softcap,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

@@ -466,6 +469,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
            config=config,
            weights=weights,
        )
        self.softcap = config.final_logit_softcapping
        assert isinstance(self.softcap, float)

    def forward(
        self,

@@ -495,4 +500,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)

        logits /= self.softcap
        logits = torch.tanh(logits)
        logits *= self.softcap

        return logits, speculative_logits