Add support for Marlin-quantized models

This change adds support for Marlin-quantized models. Marlin is an
FP16xINT4 matmul kernel that provides good speedups when decoding
batches of 16-32 tokens. It supports 4-bit quantized models with
symmetric quantization and a group size of -1 or 128.
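To make the group-size constraint concrete, here is a minimal sketch (illustrative only, not part of this change) of the scale-tensor bookkeeping those constraints imply: group size -1 means a single row of per-column scales, while group size 128 yields one scale row per 128 input rows.

```python
# Hedged sketch of the scale-tensor shapes implied by the constraints above.
# The helper name is illustrative and not part of this commit.
def marlin_scale_rows(in_features: int, group_size: int) -> int:
    """Rows of the FP16 scale tensor `s` for one Marlin-quantized layer."""
    assert group_size in (-1, 128), "Marlin supports group size -1 or 128"
    return 1 if group_size == -1 else in_features // group_size

# A 4096 -> 4096 projection:
print(marlin_scale_rows(4096, -1))   # 1  (one scale per output column)
print(marlin_scale_rows(4096, 128))  # 32 (one scale row per 128-row group)
```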

Tested with:

- Llama 2
- Llama 3
- Phi 3
Author: Daniël de Kok, 2024-06-05 08:14:40 +00:00 (committed by Daniël de Kok)
Parent: cf0d459aaf
Commit: 4594e6faba
23 changed files with 788 additions and 7 deletions


@@ -137,6 +137,13 @@ COPY server/Makefile-eetq Makefile
 # Build specific version of transformers
 RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
 
+# Build marlin kernels
+FROM kernel-builder as marlin-kernels-builder
+WORKDIR /usr/src
+COPY server/Makefile-marlin Makefile
+# Build specific version of transformers
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-marlin
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder
 WORKDIR /usr/src
@@ -205,6 +212,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from marlin kernels builder
+COPY --from=marlin-kernels-builder /usr/src/marlin/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
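A quick way to confirm the new builder stage actually lands in the final image is to import the kernels from Python inside the running container. This is a hedged sanity check, assuming the extension installs as the top-level `marlin` module (as the server code later in this change imports it):

```python
# Hedged sanity check, run inside the built image.
import torch

try:
    import marlin  # assumption: built/installed by the marlin-kernels-builder stage above
except ImportError:
    marlin = None

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
else:
    major, minor = 0, 0

print("marlin kernels importable:", marlin is not None)
print("compute capability >= 8.0:", major >= 8)  # Marlin needs SM 8.0+ (Ampere or newer)
```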


@@ -64,6 +64,7 @@ Options:
 - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
 - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
 - gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model
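For reference, a hedged example of exercising the new option end to end: after launching text-generation-inference with `--quantize marlin` and a Marlin checkpoint (the integration tests below use neuralmagic/llama-2-7b-chat-marlin), a generation request looks the same as for any other quantization backend. The address is an assumption (launcher running locally on its default port).

```python
# Hedged example: query a server started with `--quantize marlin`.
import requests

resp = requests.post(
    "http://127.0.0.1:3000/generate",  # assumption: local launcher on its default port
    json={"inputs": "Test request", "parameters": {"max_new_tokens": 10}},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```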


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0644531,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -1.2607422,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -0.11450195,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 4511,
"logprob": -0.2286377,
"special": false,
"text": " connect"
},
{
"id": 304,
"logprob": 0.0,
"special": false,
"text": " to"
},
{
"id": 1923,
"logprob": -1.2568359,
"special": false,
"text": " server"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.15905762,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -0.21618652,
"special": false,
"text": "I"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not connect to server\n\nI"
}


@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}
]


@@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
    with launcher(
        "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
    await flash_llama_marlin_handle.health(300)
    return flash_llama_marlin_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
    flash_llama_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot


@@ -64,6 +64,8 @@ enum Quantization {
     /// triton kernel (wider support) when it's not.
     /// AWQ has faster kernels.
     Gptq,
+    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
+    Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
     #[deprecated(
@@ -105,6 +107,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Marlin => {
+                write!(f, "marlin")
+            }
             Quantization::Awq => {
                 write!(f, "awq")
             }


@@ -3,6 +3,7 @@ include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
 include Makefile-eetq
+include Makefile-marlin
 include Makefile-selective-scan
 
 unit-tests:

server/Makefile-marlin (new file)

@@ -0,0 +1,13 @@
marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c

marlin:
	# Clone marlin
	pip install packaging
	git clone https://github.com/IST-DASLab/marlin.git marlin

build-marlin: marlin
	cd marlin && git fetch && git checkout $(marlin_commit)
	cd marlin && python setup.py build

install-marlin: build-marlin
	cd marlin && python setup.py install


@@ -21,6 +21,7 @@ class Quantization(str, Enum):
     eetq = "eetq"
     exl2 = "exl2"
     fp8 = "fp8"
+    marlin = "marlin"
 
 
 class Dtype(str, Enum):


@@ -222,6 +222,14 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
             )
+    elif quantize == "marlin":
+        from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight
+
+        if not isinstance(weight, MarlinWeight):
+            raise NotImplementedError(
+                f"The passed weight is not `marlin` compatible, loader needs to be updated."
+            )
+        linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
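For context, the weights loader (see the `weights.py` changes further down) is expected to hand `get_linear` a `MarlinWeight` rather than a raw tensor: `B` holds the 4-bit weights packed into int32 and `s` the FP16 scales. A minimal shape sketch, assuming the packing used by the upstream Marlin kernel (the second dimension of `B` being `out_features * 16 // 8` is an assumption; `MarlinLinear` itself only checks the first):

```python
# Hedged sketch of the tensors the loader must supply for quantize == "marlin".
# Shapes are illustrative for a 4096 -> 4096 projection with group size 128.
import torch

in_features, out_features, group_size = 4096, 4096, 128

B = torch.zeros((in_features // 16, out_features * 16 // 8), dtype=torch.int32)  # packed int4 weights
s = torch.ones((in_features // group_size, out_features), dtype=torch.float16)   # per-group scales

# get_linear(MarlinWeight(B=B, s=s), bias=None, quantize="marlin") would then
# return a MarlinLinear wrapping these buffers; any other weight type raises
# NotImplementedError, as in the branch above.
```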


@@ -0,0 +1,96 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

try:
    import marlin
except ImportError:
    marlin = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False

MARLIN_TILE_SIZE = 16


@dataclass
class MarlinWeight:
    """
    Marlin weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        s (torch.Tensor): float16 scales.
    """

    B: torch.Tensor
    s: torch.Tensor


class MarlinLinear(nn.Module):
    def __init__(
        self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if not has_sm_8_0:
            raise NotImplementedError(
                "Using quantized marlin models requires CUDA capability 8.0 or later"
            )

        if marlin is None:
            raise NotImplementedError(
                "You do not seem to have marlin installed, either install it (cd server && make install-marlin)"
            )

        assert B.dtype == torch.int32
        assert s.dtype == torch.float16

        in_features = B.shape[0] * MARLIN_TILE_SIZE
        out_features = s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisible by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisible by 256"

        group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0]
        assert group_size in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {group_size}"

        self.register_buffer("B", B)
        self.register_buffer("s", s)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        self.workspace = torch.zeros(
            out_features // 128 * 16, dtype=torch.int, device=B.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin is not None
        C = torch.empty(
            A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
        )

        marlin.mul(
            A.view((-1, A.shape[-1])),
            self.B,
            C.view((-1, C.shape[-1])),
            self.s,
            self.workspace,
        )

        if self.bias is not None:
            C += self.bias

        return C
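The constructor's checks boil down to a bit of arithmetic: `in_features` is recovered from the packed `B` via the 16-row tile size, `out_features` from the scale columns, and the group size from the number of scale rows. A hedged walk-through with a hypothetical layer (the helper below is illustrative, not part of this file):

```python
# Hedged walk-through of MarlinLinear's shape bookkeeping.
MARLIN_TILE_SIZE = 16

def describe_marlin_layer(B_rows: int, s_rows: int, s_cols: int):
    in_features = B_rows * MARLIN_TILE_SIZE        # B packs 16 input rows per tile
    out_features = s_cols
    assert in_features % 128 == 0 and out_features % 256 == 0
    group_size = -1 if s_rows == 1 else in_features // s_rows
    assert group_size in (-1, 128)
    workspace_ints = out_features // 128 * 16      # scratch buffer passed to marlin.mul
    return in_features, out_features, group_size, workspace_ints

# A 4096 -> 4096 projection quantized with group size 128:
print(describe_marlin_layer(B_rows=256, s_rows=32, s_cols=4096))
# -> (4096, 4096, 128, 512)
```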


@@ -64,7 +64,7 @@ class TensorParallelHead(SuperLayer):
         should_gather = False
 
         # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq", "eetq"]:
+        if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
             quantize = None
         # See above, exl2 LM head can be quantized or not.
         elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):


@@ -260,7 +260,7 @@
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
-        if quantize in ["awq", "exl2", "gptq"]:
+        if quantize in ["awq", "exl2", "gptq", "marlin"]:
             # These quantizers only work with float16 params.
             dtype = torch.float16
         else:


@@ -271,6 +271,11 @@ def _load_gqa(config, prefix: str, weights):
             groupsize=groupsize,
             use_exllama=use_exllama,
         )
+    elif config.quantize == "marlin":
+        # NOTE: at the time marlin support was added, the only model that
+        # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
+        # but it requires manual concatenation of weight files.
+        raise RuntimeError("dbrx models with marlin quantization are not yet supported")
     else:
         qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
         q = qkv_slice[q_start:q_stop]


@@ -145,7 +145,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.head_dim


@@ -46,6 +46,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
             prefix,
             weights,
         )
+    elif config.quantize == "marlin":
+        raise RuntimeError(
+            "GPT-2 models with marlin quantization are not yet supported"
+        )
     else:
         return _load_qkv(config, prefix, weights, head_size, num_heads)


@@ -139,7 +139,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads


@@ -89,7 +89,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads


@@ -46,7 +46,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads


@@ -29,6 +29,10 @@ def load_multi_mqa(
         return _load_multi_mqa_gptq(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )
+    elif config.quantize == "marlin":
+        raise RuntimeError(
+            "santacoder models with marlin quantization are not yet supported"
+        )
     else:
         return _load_multi_mqa(
             config, prefix, weights, bias, head_size, num_heads, hidden_size


@@ -58,7 +58,7 @@ class FlashGPT2(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
 
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
 
         prefix = ""


@@ -202,6 +202,12 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=False,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            B = self._get_qweight(f"{prefix}.B", blocks)
+            s = self._get_qweight(f"{prefix}.s", blocks)
+            weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -316,9 +322,25 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=use_exllama,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            try:
+                B = torch.cat(
+                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                )
+
+            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
+
+            weight = MarlinWeight(B=B, s=s)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
     def get_tensor_shard(self, var, dim):
@@ -481,6 +503,19 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=use_exllama,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            try:
+                B = self.get_sharded(f"{prefix}.B", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            s = self.get_sharded(f"{prefix}.s", dim=0)
+            weight = MarlinWeight(B=B, s=s)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
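The three new branches shard `B` and `s` consistently with the unquantized path: column-parallel layers load a dim=1 slice per rank and concatenate slices from several prefixes (e.g. q/k/v) along dim=1, while row-parallel layers shard along dim=0, splitting the input features and with them the scale groups. A hedged sketch of that bookkeeping with dummy tensors (shapes are illustrative):

```python
# Hedged sketch of how marlin B/s tensors are split across tensor-parallel ranks.
import torch

world_size = 2
B = torch.zeros((256, 8192), dtype=torch.int32)  # packed weights, illustrative shape
s = torch.ones((32, 4096), dtype=torch.float16)  # per-group scales

# Column-parallel (get_multi_weights_col): every rank keeps a dim=1 slice, and
# slices from several prefixes are concatenated along dim=1.
B_col = torch.chunk(B, world_size, dim=1)[0]
s_col = torch.chunk(s, world_size, dim=1)[0]

# Row-parallel (get_multi_weights_row): every rank keeps a dim=0 slice, so the
# input features (and therefore the scale groups) are split across ranks.
B_row = torch.chunk(B, world_size, dim=0)[0]
s_row = torch.chunk(s, world_size, dim=0)[0]

print(B_col.shape, s_col.shape, B_row.shape, s_row.shape)
```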