Add support for Marlin-quantized models
This change adds support for Marlin-quantized models. Marlin is an FP16xINT4 matmul kernel, which provides good speedups decoding batches of 16-32 tokens. It supports quantized models with symmetric quantization, groupsize -1 or 128, and 4-bit. Tested with: - Llama 2 - Llama 3 - Phi 3
This commit is contained in:
parent
cf0d459aaf
commit
4594e6faba
|
@ -137,6 +137,13 @@ COPY server/Makefile-eetq Makefile
|
|||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
|
||||
|
||||
# Build marlin kernels
|
||||
FROM kernel-builder as marlin-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-marlin Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-marlin
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
|
@ -205,6 +212,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
|
|||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from eetq kernels builder
|
||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from marlin kernels builder
|
||||
COPY --from=marlin-kernels-builder /usr/src/marlin/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
|
|
@ -64,6 +64,7 @@ Options:
|
|||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
|
||||
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
||||
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
||||
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3007812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 505,
|
||||
"logprob": -1.3242188,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -2.0273438,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6845703,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": -1.1748047,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.0644531,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1404,
|
||||
"logprob": -1.5224609,
|
||||
"special": false,
|
||||
"text": " user"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI have a test request from a user"
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -1.2607422,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 6527,
|
||||
"logprob": -0.11450195,
|
||||
"special": false,
|
||||
"text": " Could"
|
||||
},
|
||||
{
|
||||
"id": 451,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " not"
|
||||
},
|
||||
{
|
||||
"id": 4511,
|
||||
"logprob": -0.2286377,
|
||||
"special": false,
|
||||
"text": " connect"
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 1923,
|
||||
"logprob": -1.2568359,
|
||||
"special": false,
|
||||
"text": " server"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.15905762,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -0.21618652,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request failed: Could not connect to server\n\nI"
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3007812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 505,
|
||||
"logprob": -1.3242188,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -2.0273438,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6845703,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": -1.1748047,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.0595703,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1404,
|
||||
"logprob": -1.5224609,
|
||||
"special": false,
|
||||
"text": " user"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI have a test request from a user"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3007812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 505,
|
||||
"logprob": -1.3242188,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -2.0273438,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6845703,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": -1.1748047,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.0595703,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1404,
|
||||
"logprob": -1.5224609,
|
||||
"special": false,
|
||||
"text": " user"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI have a test request from a user"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3007812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 505,
|
||||
"logprob": -1.3242188,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -2.0273438,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6845703,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": -1.1748047,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.0595703,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1404,
|
||||
"logprob": -1.5224609,
|
||||
"special": false,
|
||||
"text": " user"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI have a test request from a user"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -12.390625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -11.0625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.0507812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3007812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29902,
|
||||
"logprob": -2.0449219,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 505,
|
||||
"logprob": -1.3242188,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1243,
|
||||
"logprob": -2.0273438,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6845703,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 515,
|
||||
"logprob": -1.1748047,
|
||||
"special": false,
|
||||
"text": " from"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.0595703,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1404,
|
||||
"logprob": -1.5224609,
|
||||
"special": false,
|
||||
"text": " user"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nI have a test request from a user"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,63 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_marlin_handle(launcher):
|
||||
with launcher(
|
||||
"neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_marlin(flash_llama_marlin_handle):
|
||||
await flash_llama_marlin_handle.health(300)
|
||||
return flash_llama_marlin_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
|
||||
response = await flash_llama_marlin.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
|
||||
response = await flash_llama_marlin.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_marlin_load(
|
||||
flash_llama_marlin, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -64,6 +64,8 @@ enum Quantization {
|
|||
/// triton kernel (wider support) when it's not.
|
||||
/// AWQ has faster kernels.
|
||||
Gptq,
|
||||
/// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
|
||||
Marlin,
|
||||
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
|
||||
/// but it is known that the model will be much slower to run than the native f16.
|
||||
#[deprecated(
|
||||
|
@ -105,6 +107,9 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
Quantization::Marlin => {
|
||||
write!(f, "marlin")
|
||||
}
|
||||
Quantization::Awq => {
|
||||
write!(f, "awq")
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ include Makefile-flash-att-v2
|
|||
include Makefile-vllm
|
||||
include Makefile-awq
|
||||
include Makefile-eetq
|
||||
include Makefile-marlin
|
||||
include Makefile-selective-scan
|
||||
|
||||
unit-tests:
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c
|
||||
|
||||
marlin:
|
||||
# Clone marlin
|
||||
pip install packaging
|
||||
git clone https://github.com/IST-DASLab/marlin.git marlin
|
||||
|
||||
build-marlin: marlin
|
||||
cd marlin && git fetch && git checkout $(marlin_commit)
|
||||
cd marlin && python setup.py build
|
||||
|
||||
install-marlin: build-marlin
|
||||
cd marlin && python setup.py install
|
|
@ -21,6 +21,7 @@ class Quantization(str, Enum):
|
|||
eetq = "eetq"
|
||||
exl2 = "exl2"
|
||||
fp8 = "fp8"
|
||||
marlin = "marlin"
|
||||
|
||||
|
||||
class Dtype(str, Enum):
|
||||
|
|
|
@ -222,6 +222,14 @@ def get_linear(weight, bias, quantize):
|
|||
raise NotImplementedError(
|
||||
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight
|
||||
|
||||
if not isinstance(weight, MarlinWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `marlin` compatible, loader needs to be updated."
|
||||
)
|
||||
linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
import marlin
|
||||
except ImportError:
|
||||
marlin = None
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
has_sm_8_0 = major >= 8
|
||||
except Exception:
|
||||
has_sm_8_0 = False
|
||||
|
||||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarlinWeight:
|
||||
"""
|
||||
Marlin weights.
|
||||
|
||||
Attributes:
|
||||
B (torch.Tensor): int4-quantized weights packed into int32.
|
||||
s (torch.Tensor): float16 scales.
|
||||
"""
|
||||
|
||||
B: torch.Tensor
|
||||
s: torch.Tensor
|
||||
|
||||
|
||||
class MarlinLinear(nn.Module):
|
||||
def __init__(
|
||||
self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not has_sm_8_0:
|
||||
raise NotImplementedError(
|
||||
"Using quantized marlin models requires CUDA capability 8.0 or later"
|
||||
)
|
||||
|
||||
if marlin is None:
|
||||
raise NotImplementedError(
|
||||
"You do not seem to have marlin installed, either install it (cd server && make install-marlin)"
|
||||
)
|
||||
|
||||
assert B.dtype == torch.int32
|
||||
assert s.dtype == torch.float16
|
||||
|
||||
in_features = B.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = s.shape[1]
|
||||
assert (
|
||||
in_features % 128 == 0
|
||||
), f"Number of input features ({in_features}) not divisable by 128"
|
||||
assert (
|
||||
out_features % 256 == 0
|
||||
), f"Number of output features ({out_features}) not divisable by 256"
|
||||
|
||||
group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0]
|
||||
assert group_size in {
|
||||
-1,
|
||||
128,
|
||||
}, f"Group size must be -1 or 128, was {group_size}"
|
||||
|
||||
self.register_buffer("B", B)
|
||||
self.register_buffer("s", s)
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.workspace = torch.zeros(
|
||||
out_features // 128 * 16, dtype=torch.int, device=B.device
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin is not None
|
||||
C = torch.empty(
|
||||
A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
|
||||
)
|
||||
marlin.mul(
|
||||
A.view((-1, A.shape[-1])),
|
||||
self.B,
|
||||
C.view((-1, C.shape[-1])),
|
||||
self.s,
|
||||
self.workspace,
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
C += self.bias
|
||||
|
||||
return C
|
|
@ -64,7 +64,7 @@ class TensorParallelHead(SuperLayer):
|
|||
should_gather = False
|
||||
|
||||
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
|
||||
if config.quantize in ["gptq", "awq", "eetq"]:
|
||||
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
|
||||
quantize = None
|
||||
# See above, exl2 LM head can be quantized or not.
|
||||
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
|
||||
|
|
|
@ -260,7 +260,7 @@ def get_model(
|
|||
) -> Model:
|
||||
global FLASH_ATTENTION
|
||||
if dtype is None:
|
||||
if quantize in ["awq", "exl2", "gptq"]:
|
||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
# These quantizers only work with float16 params.
|
||||
dtype = torch.float16
|
||||
else:
|
||||
|
|
|
@ -271,6 +271,11 @@ def _load_gqa(config, prefix: str, weights):
|
|||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif config.quantize == "marlin":
|
||||
# NOTE: at the time marlin support was added, the only model that
|
||||
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
|
||||
# but it requires manual concatenation of weight files.
|
||||
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
|
||||
else:
|
||||
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
|
||||
q = qkv_slice[q_start:q_stop]
|
||||
|
|
|
@ -145,7 +145,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.head_dim
|
||||
|
|
|
@ -46,6 +46,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
|||
prefix,
|
||||
weights,
|
||||
)
|
||||
elif config.quantize == "marlin":
|
||||
raise RuntimeError(
|
||||
"GPT-2 models with marlin quantization are not yet supported"
|
||||
)
|
||||
else:
|
||||
return _load_qkv(config, prefix, weights, head_size, num_heads)
|
||||
|
||||
|
|
|
@ -139,7 +139,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
|
|
|
@ -89,7 +89,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
|
|
|
@ -46,7 +46,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
|
|
|
@ -29,6 +29,10 @@ def load_multi_mqa(
|
|||
return _load_multi_mqa_gptq(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
)
|
||||
elif config.quantize == "marlin":
|
||||
raise RuntimeError(
|
||||
"santacoder models with marlin quantization are not yet supported"
|
||||
)
|
||||
else:
|
||||
return _load_multi_mqa(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
|
|
|
@ -58,7 +58,7 @@ class FlashGPT2(FlashCausalLM):
|
|||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
|
|
|
@ -202,6 +202,12 @@ class Weights:
|
|||
groupsize=groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeight
|
||||
|
||||
B = self._get_qweight(f"{prefix}.B", blocks)
|
||||
s = self._get_qweight(f"{prefix}.s", blocks)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
else:
|
||||
slice_ = self._get_slice(f"{prefix}.weight")
|
||||
total_size = slice_.get_shape()[0]
|
||||
|
@ -316,9 +322,25 @@ class Weights:
|
|||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeight
|
||||
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
|
||||
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
||||
return weight
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
|
@ -481,6 +503,19 @@ class Weights:
|
|||
groupsize=groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeight
|
||||
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
|
Loading…
Reference in New Issue