Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights Handling of quantized weights was split between two mechanisms: - For quantized checkpoints, we used the new weight loader infrastructure. - For quantization while loading (EETQ, FP8, bitsandbytes) we instead relied on conditional in `get_linear`. Weight loaders support context managers to selectively load particular layers with different weight loaders, which is useful for models like Idefics2 AWQ, which uses a quantized text model, but unquantized vision and connector models. However, the context manager would be overrided by `get_linear`, which string-checks `quantizer`. Also, the context manager would not work with EETQ, FP8, and bitsandbytes. This change migrates all quantizers to the weight loader infrastructure. This has several benefits: - We can use context managers with all quantizers. - All the implementation details move down to the quantizer layers, `get_linear` does not need to know how to handle quantizer linear layers. - All quantizer weights are strongly typed, we don't pass around raw tensors. - We don't have to pass around the `quantizer` string everywhere. * Exclude non-MLP layers when using FP8 quantization with Llama
This commit is contained in:
parent
1d1b1efa01
commit
ba291dad9f
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -2.7988281,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 853,
|
||||
"logprob": -2.8496094,
|
||||
"special": false,
|
||||
"text": " Un"
|
||||
},
|
||||
{
|
||||
"id": 23765,
|
||||
"logprob": -1.1894531,
|
||||
"special": false,
|
||||
"text": "supported"
|
||||
},
|
||||
{
|
||||
"id": 4714,
|
||||
"logprob": -1.5917969,
|
||||
"special": false,
|
||||
"text": " browser"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1873,
|
||||
"logprob": -1.2695312,
|
||||
"special": false,
|
||||
"text": " version"
|
||||
},
|
||||
{
|
||||
"id": 470,
|
||||
"logprob": -0.25170898,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 7481,
|
||||
"logprob": -0.21411133,
|
||||
"special": false,
|
||||
"text": " platform"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.1162109,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " failed: Unsupported browser, version or platform\n"
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -0.6645508,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 6527,
|
||||
"logprob": -2.2324219,
|
||||
"special": false,
|
||||
"text": " Could"
|
||||
},
|
||||
{
|
||||
"id": 451,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " not"
|
||||
},
|
||||
{
|
||||
"id": 6088,
|
||||
"logprob": -1.6074219,
|
||||
"special": false,
|
||||
"text": " parse"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -1.6298828,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 1206,
|
||||
"logprob": -0.72558594,
|
||||
"special": false,
|
||||
"text": " case"
|
||||
},
|
||||
{
|
||||
"id": 1024,
|
||||
"logprob": -0.40429688,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 525,
|
||||
"logprob": -1.2519531,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request failed: Could not parse test case name from '"
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -2.7988281,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 853,
|
||||
"logprob": -2.8496094,
|
||||
"special": false,
|
||||
"text": " Un"
|
||||
},
|
||||
{
|
||||
"id": 23765,
|
||||
"logprob": -1.1894531,
|
||||
"special": false,
|
||||
"text": "supported"
|
||||
},
|
||||
{
|
||||
"id": 4714,
|
||||
"logprob": -1.5917969,
|
||||
"special": false,
|
||||
"text": " browser"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1873,
|
||||
"logprob": -1.2695312,
|
||||
"special": false,
|
||||
"text": " version"
|
||||
},
|
||||
{
|
||||
"id": 470,
|
||||
"logprob": -0.25170898,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 7481,
|
||||
"logprob": -0.21411133,
|
||||
"special": false,
|
||||
"text": " platform"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.1162109,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " failed: Unsupported browser, version or platform\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -2.7988281,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 853,
|
||||
"logprob": -2.8496094,
|
||||
"special": false,
|
||||
"text": " Un"
|
||||
},
|
||||
{
|
||||
"id": 23765,
|
||||
"logprob": -1.1894531,
|
||||
"special": false,
|
||||
"text": "supported"
|
||||
},
|
||||
{
|
||||
"id": 4714,
|
||||
"logprob": -1.5917969,
|
||||
"special": false,
|
||||
"text": " browser"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1873,
|
||||
"logprob": -1.2695312,
|
||||
"special": false,
|
||||
"text": " version"
|
||||
},
|
||||
{
|
||||
"id": 470,
|
||||
"logprob": -0.25170898,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 7481,
|
||||
"logprob": -0.21411133,
|
||||
"special": false,
|
||||
"text": " platform"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.1162109,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " failed: Unsupported browser, version or platform\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -2.7988281,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 853,
|
||||
"logprob": -2.8496094,
|
||||
"special": false,
|
||||
"text": " Un"
|
||||
},
|
||||
{
|
||||
"id": 23765,
|
||||
"logprob": -1.1894531,
|
||||
"special": false,
|
||||
"text": "supported"
|
||||
},
|
||||
{
|
||||
"id": 4714,
|
||||
"logprob": -1.5917969,
|
||||
"special": false,
|
||||
"text": " browser"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1873,
|
||||
"logprob": -1.2695312,
|
||||
"special": false,
|
||||
"text": " version"
|
||||
},
|
||||
{
|
||||
"id": 470,
|
||||
"logprob": -0.25170898,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 7481,
|
||||
"logprob": -0.21411133,
|
||||
"special": false,
|
||||
"text": " platform"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.1162109,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " failed: Unsupported browser, version or platform\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.0859375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -16.359375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -2.7988281,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 853,
|
||||
"logprob": -2.8496094,
|
||||
"special": false,
|
||||
"text": " Un"
|
||||
},
|
||||
{
|
||||
"id": 23765,
|
||||
"logprob": -1.1894531,
|
||||
"special": false,
|
||||
"text": "supported"
|
||||
},
|
||||
{
|
||||
"id": 4714,
|
||||
"logprob": -1.5917969,
|
||||
"special": false,
|
||||
"text": " browser"
|
||||
},
|
||||
{
|
||||
"id": 29892,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1873,
|
||||
"logprob": -1.2695312,
|
||||
"special": false,
|
||||
"text": " version"
|
||||
},
|
||||
{
|
||||
"id": 470,
|
||||
"logprob": -0.25170898,
|
||||
"special": false,
|
||||
"text": " or"
|
||||
},
|
||||
{
|
||||
"id": 7481,
|
||||
"logprob": -0.21411133,
|
||||
"special": false,
|
||||
"text": " platform"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.1162109,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " failed: Unsupported browser, version or platform\n"
|
||||
}
|
||||
]
|
|
@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
|
|||
chunk = [c.replace("data:", "") for c in chunk]
|
||||
# remove empty strings
|
||||
chunk = [c for c in chunk if c]
|
||||
# remove completion marking chunk
|
||||
chunk = [c for c in chunk if c != " [DONE]"]
|
||||
# parse json
|
||||
chunk = [json.loads(c) for c in chunk]
|
||||
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_marlin24_handle(launcher):
|
||||
with launcher(
|
||||
"nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_marlin(flash_llama_marlin24_handle):
|
||||
await flash_llama_marlin24_handle.health(300)
|
||||
return flash_llama_marlin24_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
|
||||
response = await flash_llama_marlin.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
|
||||
response = await flash_llama_marlin.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin24_load(
|
||||
flash_llama_marlin, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
import torch
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
)
|
||||
|
@ -363,7 +364,10 @@ class MockWeights(Weights):
|
|||
self.process_group = process_group
|
||||
self.prefix = prefix
|
||||
self.weights_loader = (
|
||||
DefaultWeightsLoader() if weights_loader is None else weights_loader
|
||||
# We don't need to get linear layers, so just wrap raw tensors.
|
||||
DefaultWeightsLoader(lambda x: x)
|
||||
if weights_loader is None
|
||||
else weights_loader
|
||||
)
|
||||
self._handles = {}
|
||||
|
||||
|
@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
|
|||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=True,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
|
|||
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
|
|||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=False,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
|
|||
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
|
|||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=True,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
|
|||
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
|||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=False,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
|||
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
|
|||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=True,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
|
|||
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
|||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=False,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
|||
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
|
|||
g_idx=None,
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=True,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
|
|||
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
|
|||
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
|
||||
bits=8.0,
|
||||
groupsize=2.0,
|
||||
use_awq_kernel=False,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
|
|||
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
|
||||
assert w.bits == expected_weight.bits, "bits mismatch"
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import torch
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.weights import Weight
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
|
@ -12,6 +15,14 @@ def warn_deprecate_bnb():
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BNBWeight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
|
||||
|
||||
|
||||
class Linear8bitLt(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -70,6 +81,22 @@ class Linear8bitLt(torch.nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class BNBFP4Weight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return Linear4bit(self.weight, bias, quant_type="fp4")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BNBNF4Weight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return Linear4bit(self.weight, bias, quant_type="nf4")
|
||||
|
||||
|
||||
class Linear4bit(torch.nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
|
|
|
@ -1,5 +1,23 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from EETQ import quant_weights, w8_a16_gemm
|
||||
from text_generation_server.utils.weights import Weight
|
||||
|
||||
|
||||
@dataclass
|
||||
class EETQWeight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
try:
|
||||
from text_generation_server.layers.eetq import EETQLinear
|
||||
|
||||
return EETQLinear(self.weight, bias)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
|
||||
|
||||
class EETQLinear(torch.nn.Module):
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
from typing import List, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
from text_generation_server.utils.weights import WeightsLoader, Weights
|
||||
import torch
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exl2Weight:
|
||||
class Exl2Weight(Weight):
|
||||
"""
|
||||
Exllama2 exl2 quantized weights.
|
||||
"""
|
||||
|
@ -25,6 +25,11 @@ class Exl2Weight:
|
|||
def device(self) -> torch.device:
|
||||
return self.q_weight.device
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||
|
||||
return ExllamaQuantLinear(self, bias)
|
||||
|
||||
|
||||
class Exl2WeightsLoader(WeightsLoader):
|
||||
"""Loader for exl2-quantized weights."""
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weight
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
|
@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
|||
return qweight, scale
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Weight(Weight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return get_fp8_linear()(self.weight, bias)
|
||||
|
||||
|
||||
class Fp8Linear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1,24 +1,23 @@
|
|||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from safetensors import SafetensorError
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQWeight:
|
||||
class GPTQWeight(Weight):
|
||||
qweight: torch.Tensor
|
||||
qzeros: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
g_idx: Optional[torch.Tensor]
|
||||
bits: int
|
||||
groupsize: int
|
||||
use_awq_kernel: bool
|
||||
use_exllama: bool
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -29,6 +28,50 @@ class GPTQWeight:
|
|||
def device(self) -> torch.device:
|
||||
return self.qweight.device
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.use_awq_kernel:
|
||||
if SYSTEM == "rocm":
|
||||
raise NotImplementedError(
|
||||
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
|
||||
"to use Exllama/GPTQ kernels for AWQ inference."
|
||||
)
|
||||
try:
|
||||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
||||
|
||||
return WQLinear(
|
||||
w_bit=self.bits,
|
||||
group_size=self.groupsize,
|
||||
qweight=self.qweight,
|
||||
qzeros=self.qzeros,
|
||||
scales=self.scales,
|
||||
bias=bias,
|
||||
)
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||
)
|
||||
elif self.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
|
||||
)
|
||||
|
||||
return ExllamaQuantLinear(self, bias)
|
||||
else:
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
return QuantLinear(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx,
|
||||
bias,
|
||||
self.bits,
|
||||
self.groupsize,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
|
@ -45,6 +88,8 @@ elif CAN_EXLLAMA:
|
|||
if V2:
|
||||
from text_generation_server.layers.gptq.exllamav2 import (
|
||||
QuantLinear as ExllamaQuantLinear,
|
||||
)
|
||||
from text_generation_server.layers.gptq.exllamav2 import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
@ -53,6 +98,8 @@ elif CAN_EXLLAMA:
|
|||
else:
|
||||
from text_generation_server.layers.gptq.exllama import (
|
||||
Ex4bitLinear as ExllamaQuantLinear,
|
||||
)
|
||||
from text_generation_server.layers.gptq.exllama import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_awq_kernel=self.quantize == "awq",
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
|
@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_awq_kernel=self.quantize == "awq",
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
|
@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
use_exllama = False
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
HAS_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
|
@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_awq_kernel=self.quantize == "awq",
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from torch.nn import functional as F
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
|
@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module):
|
|||
return F.linear(inp, self.weight, self.bias)
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
def get_linear(weight, bias):
|
||||
# Weights that are loaded through methods that are not
|
||||
# quantization-aware are still bare tensors. We may want
|
||||
# to change this in the future.
|
||||
if isinstance(weight, torch.Tensor):
|
||||
if SYSTEM == "rocm":
|
||||
linear = FastLinearROCm(weight, bias)
|
||||
return FastLinearROCm(weight, bias)
|
||||
else:
|
||||
linear = FastLinear(weight, bias)
|
||||
elif quantize == "eetq":
|
||||
try:
|
||||
from text_generation_server.layers.eetq import EETQLinear
|
||||
return FastLinear(weight, bias)
|
||||
|
||||
linear = EETQLinear(weight, bias)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import get_fp8_linear
|
||||
|
||||
linear = get_fp8_linear()(weight, bias)
|
||||
elif quantize == "bitsandbytes":
|
||||
try:
|
||||
from text_generation_server.layers.bnb import (
|
||||
warn_deprecate_bnb,
|
||||
Linear8bitLt,
|
||||
)
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
|
||||
)
|
||||
warn_deprecate_bnb()
|
||||
linear = Linear8bitLt(
|
||||
weight,
|
||||
bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
try:
|
||||
from text_generation_server.layers.bnb import Linear4bit
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
|
||||
)
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="fp4",
|
||||
)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
try:
|
||||
from text_generation_server.layers.bnb import Linear4bit
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
|
||||
)
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
|
||||
if not isinstance(weight, Exl2Weight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `exl2` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||
|
||||
linear = ExllamaQuantLinear(weight, bias)
|
||||
|
||||
elif quantize == "gptq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlinWeight):
|
||||
linear = GPTQMarlinLinear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQWeight):
|
||||
if weight.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import (
|
||||
ExllamaQuantLinear,
|
||||
)
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
|
||||
)
|
||||
|
||||
linear = ExllamaQuantLinear(weight, bias)
|
||||
else:
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
linear = QuantLinear(
|
||||
weight.qweight,
|
||||
weight.qzeros,
|
||||
weight.scales,
|
||||
weight.g_idx,
|
||||
bias,
|
||||
weight.bits,
|
||||
weight.groupsize,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
elif quantize == "awq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
if not isinstance(weight, GPTQWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||
)
|
||||
if SYSTEM == "rocm":
|
||||
raise NotImplementedError(
|
||||
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
|
||||
"to use Exllama/GPTQ kernels for AWQ inference."
|
||||
)
|
||||
try:
|
||||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
||||
|
||||
linear = WQLinear(
|
||||
w_bit=weight.bits,
|
||||
group_size=weight.groupsize,
|
||||
qweight=weight.qweight,
|
||||
qzeros=weight.qzeros,
|
||||
scales=weight.scales,
|
||||
bias=bias,
|
||||
)
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Linear,
|
||||
GPTQMarlin24Weight,
|
||||
MarlinLinear,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlin24Weight):
|
||||
linear = GPTQMarlin24Linear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, MarlinWeight):
|
||||
linear = MarlinLinear(weight=weight, bias=bias)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `marlin` compatible, loader needs to be updated."
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
return weight.get_linear(bias)
|
||||
|
|
|
@ -7,7 +7,7 @@ from loguru import logger
|
|||
from text_generation_server.layers.fp8 import fp8_quantize
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
|
@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader):
|
|||
return weight
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
if self.is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
|
@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader):
|
|||
return weight
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
if self.is_marlin_24:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
|
@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor):
|
|||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlinWeight:
|
||||
class GPTQMarlinWeight(Weight):
|
||||
"""
|
||||
Repacked GPTQ Marlin weights.
|
||||
"""
|
||||
|
@ -219,6 +217,12 @@ class GPTQMarlinWeight:
|
|||
assert self.g_idx.dtype == torch.int32
|
||||
assert self.perm.dtype == torch.int32
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return GPTQMarlinLinear(
|
||||
weight=self,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
|
||||
def repack_gptq_for_marlin(
|
||||
*,
|
||||
|
@ -376,6 +380,12 @@ class GPTQMarlin24Weight:
|
|||
assert self.B_meta.dtype == torch.int16
|
||||
assert self.s.dtype == torch.float16
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return GPTQMarlin24Linear(
|
||||
weight=self,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
|
||||
class GPTQMarlin24Linear(nn.Module):
|
||||
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
|
||||
|
@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
|
|||
|
||||
|
||||
@dataclass
|
||||
class MarlinWeight:
|
||||
class MarlinWeight(Weight):
|
||||
"""
|
||||
Marlin weights.
|
||||
|
||||
|
@ -583,6 +593,9 @@ class MarlinWeight:
|
|||
assert self.B.dtype == torch.int32
|
||||
assert self.s.dtype == torch.float16
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return MarlinLinear(weight=self, bias=bias)
|
||||
|
||||
|
||||
class MarlinLinear(nn.Module):
|
||||
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
|
||||
|
|
|
@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer):
|
|||
quantize = config.quantize
|
||||
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
get_linear(weight, bias=None),
|
||||
process_group=weights.process_group,
|
||||
should_gather=should_gather,
|
||||
)
|
||||
|
@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
raise NotImplementedError("packed_gate_up only implemented without bias")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
|
@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
|
@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
|
@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
for prefix in prefixes:
|
||||
weight = weights.get_weights_col(prefix)
|
||||
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||
linears.append(get_linear(weight, b, config.quantize))
|
||||
linears.append(get_linear(weight, b))
|
||||
linear = LayerConcat(linears)
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(prefixes, dim=dim)
|
||||
|
@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
bias = torch.cat(b, dim=dim)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
return cls(linear)
|
||||
|
||||
|
||||
|
@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer):
|
|||
else:
|
||||
bias = None
|
||||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
get_linear(weight, bias),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
|
|
@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=bias, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
|
||||
|
||||
|
||||
class FlashCohereAttention(torch.nn.Module):
|
||||
|
|
|
@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls):
|
|||
|
||||
if cls == TensorParallelRowLinear:
|
||||
expert_slice = expert_slice.t().contiguous()
|
||||
linear = get_linear(expert_slice, None, config.quantize)
|
||||
linear = get_linear(expert_slice, None)
|
||||
experts.append(cls(linear, weights.process_group))
|
||||
else:
|
||||
linear = get_linear(expert_slice, None, config.quantize)
|
||||
linear = get_linear(expert_slice, None)
|
||||
experts.append(cls(linear))
|
||||
|
||||
return experts
|
||||
|
|
|
@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||
|
||||
|
||||
class FlashGemma2Attention(torch.nn.Module):
|
||||
|
|
|
@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||
|
||||
|
||||
class FlashGemmaAttention(torch.nn.Module):
|
||||
|
|
|
@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights):
|
|||
bias = torch.cat(tensors, dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
|
||||
|
||||
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
|
@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
|
|||
3 * num_heads * head_size
|
||||
], f"{weight.shape} != {[3 * num_heads * head_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
|
@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
bias = None
|
||||
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
get_linear(weight, bias), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
|
||||
|
||||
class FlashGPT2Attention(torch.nn.Module):
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -25,7 +26,6 @@ import torch.distributed
|
|||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
|
@ -42,10 +42,16 @@ from text_generation_server.layers import (
|
|||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
|
@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
|
|||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_fp8(weights: Weights):
|
||||
weights_loader = weights.weights_loader
|
||||
if (
|
||||
isinstance(weights_loader, DefaultWeightsLoader)
|
||||
and weights_loader.weight_class is Fp8Weight
|
||||
):
|
||||
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
|
||||
|
||||
with weights.use_loader(weights_loader):
|
||||
yield
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
|
|||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, index, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
index=index,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
with no_fp8(weights):
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
index=index,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.mlp = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||
)
|
||||
|
@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
with no_fp8(weights):
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens"
|
||||
if not prefix
|
||||
else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights)
|
||||
if config.tie_word_embeddings:
|
||||
suffix = "model.embed_tokens"
|
||||
else:
|
||||
suffix = "lm_head"
|
||||
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
)
|
||||
with no_fp8(weights):
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||
|
||||
|
||||
def _load_experts(config, prefix: str, mat, weights):
|
||||
|
|
|
@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
else:
|
||||
bias = None
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
if config.use_parallel_residual:
|
||||
return linear
|
||||
else:
|
||||
|
@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
|||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
if config.use_parallel_residual:
|
||||
return linear
|
||||
else:
|
||||
|
|
|
@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=True, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=True))
|
||||
|
||||
|
||||
class FlashPhiAttention(torch.nn.Module):
|
||||
|
|
|
@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
else:
|
||||
bias = None
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
if config.parallel_attn:
|
||||
return linear
|
||||
else:
|
||||
|
|
|
@ -105,6 +105,7 @@ def _load_multi_mqa_gptq(
|
|||
g_idx=g_idx,
|
||||
bits=loader.bits,
|
||||
groupsize=loader.groupsize,
|
||||
use_awq_kernel=loader.quantize == "awq",
|
||||
use_exllama=HAS_EXLLAMA,
|
||||
)
|
||||
|
||||
|
@ -121,7 +122,7 @@ def _load_multi_mqa_gptq(
|
|||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
else:
|
||||
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
||||
|
||||
|
@ -193,7 +194,7 @@ def _load_multi_mqa(
|
|||
assert list(bias.shape) == [
|
||||
(num_heads + 2) * head_size
|
||||
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
|
||||
|
||||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
|
@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias))
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
|
@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
else:
|
||||
bias = None
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
get_linear(weight, bias), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=bias, quantize=config.quantize)
|
||||
)
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
|
||||
|
||||
|
||||
class Starcoder2Attention(torch.nn.Module):
|
||||
|
|
|
@ -34,7 +34,7 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
@ -698,7 +698,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
self.dtype = weights.dtype
|
||||
|
||||
# The vision and connector models are not quantized.
|
||||
with weights.use_loader(DefaultWeightsLoader()):
|
||||
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
|
||||
self.vision_model = Idefics2VisionTransformer(
|
||||
prefix=(
|
||||
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
|
||||
|
@ -707,16 +707,12 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
weights=weights,
|
||||
)
|
||||
|
||||
quantize = config.quantize
|
||||
try:
|
||||
config.quantize = None
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
finally:
|
||||
config.quantize = quantize
|
||||
config.quantize = None
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||
|
|
|
@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias):
|
|||
bias = bias.to(device=weights.device)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
linear = get_linear(weight, bias)
|
||||
return TensorParallelColumnLinear(linear)
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from typing import Optional
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
WeightsLoader,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -104,10 +107,30 @@ def get_loader(
|
|||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
)
|
||||
elif quantize == "bitsandbytes":
|
||||
from text_generation_server.layers.bnb import BNBWeight
|
||||
|
||||
return DefaultWeightsLoader(BNBWeight)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
from text_generation_server.layers.bnb import BNBFP4Weight
|
||||
|
||||
return DefaultWeightsLoader(BNBFP4Weight)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
from text_generation_server.layers.bnb import BNBNF4Weight
|
||||
|
||||
return DefaultWeightsLoader(BNBNF4Weight)
|
||||
elif quantize == "eetq":
|
||||
from text_generation_server.layers.eetq import EETQWeight
|
||||
|
||||
return DefaultWeightsLoader(EETQWeight)
|
||||
elif quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
||||
|
||||
return Exl2WeightsLoader()
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
return DefaultWeightsLoader(Fp8Weight)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
||||
|
||||
|
@ -115,5 +138,7 @@ def get_loader(
|
|||
bits=quantizer_config.bits,
|
||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||
)
|
||||
elif quantize is None:
|
||||
return DefaultWeightsLoader(UnquantizedWeight)
|
||||
else:
|
||||
return DefaultWeightsLoader()
|
||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
class WeightsLoader(ABC):
|
||||
|
@ -62,7 +66,39 @@ class WeightsLoader(ABC):
|
|||
...
|
||||
|
||||
|
||||
class Weight(ABC):
|
||||
"""Instances of this type implement unquantized/quantized/to-be
|
||||
quantized weights."""
|
||||
|
||||
@abstractmethod
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
"""Create a linear layer from this weight."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnquantizedWeight:
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
return FastLinearROCm(self.weight, bias)
|
||||
else:
|
||||
return FastLinear(self.weight, bias)
|
||||
|
||||
|
||||
class DefaultWeightsLoader(WeightsLoader):
|
||||
"""Weight loader that loads (unquantized) Torch tensors."""
|
||||
|
||||
def __init__(self, weight_class):
|
||||
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
||||
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
||||
such as `Fp8Weight` can be used to quantize the weights during loading.
|
||||
"""
|
||||
self.weight_class = weight_class
|
||||
|
||||
"""
|
||||
Loader that uses tensors as-is with the exception of applying sharding
|
||||
and/or concatenation.
|
||||
|
@ -74,16 +110,21 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
return weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
|
||||
return self.weight_class(
|
||||
weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
),
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
return torch.cat(w, dim=dim)
|
||||
return self.weight_class(torch.cat(w, dim=dim))
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
return weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return self.weight_class(
|
||||
weights.get_sharded(f"{prefix}.weight", dim=1),
|
||||
)
|
||||
|
||||
|
||||
class Weights:
|
||||
|
|
Loading…
Reference in New Issue