Softcapping for gemma2. (#2273)
* Softcapping for gemma2. * Less clutter. * No access to transformers config, only config_dict here. * 0.0 is the null value in the C++ API.
This commit is contained in:
parent
4844ff790a
commit
6aeb669072
|
@ -0,0 +1,254 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 106,
|
||||||
|
"logprob": -47.25,
|
||||||
|
"text": "<start_of_turn>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1645,
|
||||||
|
"logprob": -18.875,
|
||||||
|
"text": "user"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -7.15625,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -4.78125,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5559,
|
||||||
|
"logprob": -10.0,
|
||||||
|
"text": "Write"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 476,
|
||||||
|
"logprob": -0.1171875,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19592,
|
||||||
|
"logprob": -2.46875,
|
||||||
|
"text": " poem"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 577,
|
||||||
|
"logprob": -5.84375,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1707,
|
||||||
|
"logprob": -6.375,
|
||||||
|
"text": " help"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 682,
|
||||||
|
"logprob": -2.125,
|
||||||
|
"text": " me"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5434,
|
||||||
|
"logprob": -1.546875,
|
||||||
|
"text": " remember"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -0.62890625,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1370,
|
||||||
|
"logprob": -6.65625,
|
||||||
|
"text": " first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.84375,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -0.45117188,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235276,
|
||||||
|
"logprob": -0.07421875,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6635,
|
||||||
|
"logprob": -2.109375,
|
||||||
|
"text": " elements"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 611,
|
||||||
|
"logprob": -0.4140625,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -0.0009536743,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26163,
|
||||||
|
"logprob": -0.033203125,
|
||||||
|
"text": " periodic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3037,
|
||||||
|
"logprob": -0.0002670288,
|
||||||
|
"text": " table"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235269,
|
||||||
|
"logprob": -4.75,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7385,
|
||||||
|
"logprob": -11.625,
|
||||||
|
"text": " giving"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1853,
|
||||||
|
"logprob": -4.875,
|
||||||
|
"text": " each"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5356,
|
||||||
|
"logprob": -0.38867188,
|
||||||
|
"text": " element"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1277,
|
||||||
|
"logprob": -3.65625,
|
||||||
|
"text": " its"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1997,
|
||||||
|
"logprob": -4.4375,
|
||||||
|
"text": " own"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2017,
|
||||||
|
"logprob": -0.29882812,
|
||||||
|
"text": " line"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235265,
|
||||||
|
"logprob": -0.16699219,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 107,
|
||||||
|
"logprob": -25.625,
|
||||||
|
"text": "<end_of_turn>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -6.75,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 106,
|
||||||
|
"logprob": -39.5,
|
||||||
|
"text": "<start_of_turn>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2516,
|
||||||
|
"logprob": -32.5,
|
||||||
|
"text": "model"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -10.125,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -3.421875,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 688,
|
||||||
|
"logprob": -0.546875,
|
||||||
|
"special": false,
|
||||||
|
"text": "**"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 103889,
|
||||||
|
"logprob": -0.49023438,
|
||||||
|
"special": false,
|
||||||
|
"text": "Hydrogen"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 190213,
|
||||||
|
"logprob": -0.48632812,
|
||||||
|
"special": false,
|
||||||
|
"text": "**,"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2611,
|
||||||
|
"logprob": -0.58203125,
|
||||||
|
"special": false,
|
||||||
|
"text": " light"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 578,
|
||||||
|
"logprob": -0.099121094,
|
||||||
|
"special": false,
|
||||||
|
"text": " and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2223,
|
||||||
|
"logprob": -1.078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " free"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235269,
|
||||||
|
"logprob": -0.025756836,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.29101562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 688,
|
||||||
|
"logprob": -0.0035858154,
|
||||||
|
"special": false,
|
||||||
|
"text": "**"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1949,
|
||||||
|
"logprob": -4.1007996e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "He"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "**Hydrogen**, light and free,\n**He"
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,46 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_gemma2_handle(launcher):
|
||||||
|
with launcher("google/gemma-2-9b-it", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_gemma2(flash_gemma2_handle):
|
||||||
|
await flash_gemma2_handle.health(300)
|
||||||
|
return flash_gemma2_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma2(flash_gemma2, response_snapshot):
|
||||||
|
response = await flash_gemma2.generate(
|
||||||
|
"<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
|
||||||
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.generated_text == "**Hydrogen**, light and free,\n**He"
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_gemma2,
|
||||||
|
"<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert responses[0].generated_text == "**Hydrogen**, light and free,\n**He"
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -25,6 +25,7 @@ mod env_runtime;
|
||||||
struct RawConfig {
|
struct RawConfig {
|
||||||
max_position_embeddings: Option<usize>,
|
max_position_embeddings: Option<usize>,
|
||||||
n_positions: Option<usize>,
|
n_positions: Option<usize>,
|
||||||
|
model_type: Option<String>,
|
||||||
max_seq_len: Option<usize>,
|
max_seq_len: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1418,6 +1419,11 @@ fn main() -> Result<(), LauncherError> {
|
||||||
|
|
||||||
let content = std::fs::read_to_string(filename)?;
|
let content = std::fs::read_to_string(filename)?;
|
||||||
let config: RawConfig = serde_json::from_str(&content)?;
|
let config: RawConfig = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
|
if config.model_type == Some("gemma2".to_string()) {
|
||||||
|
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||||
|
std::env::set_var("FLASH_DECODING", "1");
|
||||||
|
}
|
||||||
let config: Config = config.into();
|
let config: Config = config.into();
|
||||||
|
|
||||||
// Quantization usually means you're even more RAM constrained.
|
// Quantization usually means you're even more RAM constrained.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
flash_att_v2_commit_cuda := v2.5.9.post1
|
flash_att_v2_commit_cuda := v2.6.1
|
||||||
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||||
|
|
||||||
build-flash-attention-v2-cuda:
|
build-flash-attention-v2-cuda:
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
|
||||||
sentencepiece = "^0.1.97"
|
sentencepiece = "^0.1.97"
|
||||||
tokenizers = "^0.19.1"
|
tokenizers = "^0.19.1"
|
||||||
huggingface-hub = "^0.23"
|
huggingface-hub = "^0.23"
|
||||||
transformers = "^4.41"
|
transformers = "^4.42"
|
||||||
einops = "^0.6.1"
|
einops = "^0.6.1"
|
||||||
texttable = { version = "^1.6.7", optional = true }
|
texttable = { version = "^1.6.7", optional = true }
|
||||||
datasets = { version = "^2.14.0", optional = true }
|
datasets = { version = "^2.14.0", optional = true }
|
||||||
|
|
|
@ -1,48 +1,50 @@
|
||||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,48 +1,50 @@
|
||||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,48 +1,50 @@
|
||||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
is_sm75 = major == 7 and minor == 5
|
is_sm75 = major == 7 and minor == 5
|
||||||
|
@ -43,6 +44,7 @@ def paged_attention(
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
# Copyright 2023 The vLLM team. All rights
|
# Copyright 2023 The vLLM team. All rights
|
||||||
|
@ -82,6 +84,8 @@ def paged_attention(
|
||||||
# by the current path
|
# by the current path
|
||||||
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
||||||
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
out2 = flash_attn_2_cuda.varlen_fwd(
|
out2 = flash_attn_2_cuda.varlen_fwd(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
|
@ -89,6 +93,7 @@ def paged_attention(
|
||||||
None,
|
None,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_k,
|
seqlen.cu_seqlen_k,
|
||||||
|
None, # pad_k
|
||||||
None,
|
None,
|
||||||
block_tables,
|
block_tables,
|
||||||
None,
|
None,
|
||||||
|
@ -100,11 +105,14 @@ def paged_attention(
|
||||||
True, # causal
|
True, # causal
|
||||||
-1, # Window_left
|
-1, # Window_left
|
||||||
-1, # Window right
|
-1, # Window right
|
||||||
|
softcap,
|
||||||
False, # return softmax
|
False, # return softmax
|
||||||
None, # generator
|
None, # generator
|
||||||
)
|
)
|
||||||
return out2[0]
|
return out2[0]
|
||||||
else:
|
else:
|
||||||
|
if softcap is not None:
|
||||||
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
|
@ -205,6 +213,7 @@ if V2:
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
causal=True,
|
causal=True,
|
||||||
|
softcap=0.0,
|
||||||
):
|
):
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
@ -218,6 +227,7 @@ if V2:
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
max_s,
|
max_s,
|
||||||
max_s,
|
max_s,
|
||||||
0.0,
|
0.0,
|
||||||
|
@ -226,6 +236,7 @@ if V2:
|
||||||
causal,
|
causal,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
0,
|
0,
|
||||||
|
softcap,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
@ -241,11 +252,14 @@ else:
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
|
softcap=None,
|
||||||
):
|
):
|
||||||
if window_size_left != -1:
|
if window_size_left != -1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"window_size_left is only available with flash attn v2"
|
"window_size_left is only available with flash attn v2"
|
||||||
)
|
)
|
||||||
|
if softcap is not None:
|
||||||
|
raise NotImplementedError("softcap is only available with flash attn v2")
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if k.shape[1] != q.shape[1]:
|
||||||
|
|
|
@ -762,6 +762,8 @@ def get_model(
|
||||||
default_dtype=torch.bfloat16,
|
default_dtype=torch.bfloat16,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
lora_adapter_ids=lora_adapter_ids,
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
# hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
|
||||||
|
head_size=config_dict["head_dim"],
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||||
|
|
|
@ -189,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_key_value_heads // weights.process_group.size()
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
self.softcap = config.attn_logit_softcapping
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
|
||||||
|
@ -246,6 +247,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
window_size_left=self.window_size,
|
window_size_left=self.window_size,
|
||||||
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
|
@ -259,6 +261,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
block_tables,
|
block_tables,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
@ -466,6 +469,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
self.softcap = config.final_logit_softcapping
|
||||||
|
assert isinstance(self.softcap, float)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -495,4 +500,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
logits /= self.softcap
|
||||||
|
logits = torch.tanh(logits)
|
||||||
|
logits *= self.softcap
|
||||||
|
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
Loading…
Reference in New Issue