Use GPTQ-Marlin for supported GPTQ configurations (#2111)
GPTQ-Marlin is currently the best-performing kernel for GPTQ models. So let's use it by default if the kernels are installed, the GPU supports it, and the kernels support the configuration. For models generated by `text-generation-server quantize`, use `sym=False`. This subcommand symmetric quantization since the beginning and incorrectly reporting the model to be symmetric will use GPTQ-Marlin (which does not support asymmetric quantization).
This commit is contained in:
parent
0d97a93c1e
commit
2ce8019480
|
@ -1,84 +0,0 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.6230469,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3270,
|
||||
"logprob": -2.046875,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1425781,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.9238281,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13204,
|
||||
"logprob": -0.076660156,
|
||||
"special": false,
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 3019,
|
||||
"logprob": -0.10821533,
|
||||
"special": false,
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.2539062,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 578,
|
||||
"logprob": -0.15563965,
|
||||
"special": false,
|
||||
"text": " The"
|
||||
},
|
||||
{
|
||||
"id": 3622,
|
||||
"logprob": -0.8203125,
|
||||
"special": false,
|
||||
"text": " server"
|
||||
},
|
||||
{
|
||||
"id": 706,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " has"
|
||||
},
|
||||
{
|
||||
"id": 539,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " not"
|
||||
},
|
||||
{
|
||||
"id": 3686,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " yet"
|
||||
},
|
||||
{
|
||||
"id": 3288,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " sent"
|
||||
},
|
||||
{
|
||||
"id": 904,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " any"
|
||||
},
|
||||
{
|
||||
"id": 828,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " data"
|
||||
},
|
||||
{
|
||||
"id": 382,
|
||||
"logprob": -1.5517578,
|
||||
"special": false,
|
||||
"text": ".\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request. The server has not yet sent any data.\n\n"
|
||||
}
|
|
@ -1,338 +0,0 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
}
|
||||
]
|
|
@ -1,68 +0,0 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_gptq_marlin_handle(launcher):
|
||||
with launcher(
|
||||
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
|
||||
await flash_llama_gptq_marlin_handle.health(300)
|
||||
return flash_llama_gptq_marlin_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
|
||||
response = await flash_llama_gptq_marlin.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq_marlin_all_params(
|
||||
flash_llama_gptq_marlin, response_snapshot
|
||||
):
|
||||
response = await flash_llama_gptq_marlin.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq_marlin_load(
|
||||
flash_llama_gptq_marlin, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQWeight:
|
||||
qweight: torch.Tensor
|
||||
|
|
|
@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
|
|||
|
||||
elif quantize == "gptq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
if not isinstance(weight, GPTQWeight):
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlinWeight):
|
||||
linear = GPTQMarlinLinear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQWeight):
|
||||
if weight.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import (
|
||||
|
@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
|
|||
weight.bits,
|
||||
weight.groupsize,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
elif quantize == "awq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
|
@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
|
|||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Linear,
|
||||
GPTQMarlin24Weight,
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
MarlinLinear,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
if isinstance(weight, GPTQMarlinWeight):
|
||||
linear = GPTQMarlinLinear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQMarlin24Weight):
|
||||
if isinstance(weight, GPTQMarlin24Weight):
|
||||
linear = GPTQMarlin24Linear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
|
|
|
@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
try:
|
||||
|
@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and gptq_params.quant_method == "gptq"
|
||||
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and gptq_params.sym
|
||||
)
|
||||
|
||||
|
||||
def _check_marlin_kernels():
|
||||
if not (SYSTEM == "cuda" and has_sm_8_0):
|
||||
raise NotImplementedError(
|
||||
|
|
|
@ -1,25 +1,15 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
from loguru import logger
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -212,6 +202,10 @@ class Weights:
|
|||
"""
|
||||
if quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
|
@ -221,17 +215,28 @@ class Weights:
|
|||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
|
@ -269,7 +274,6 @@ class Weights:
|
|||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
B = self.get_packed_sharded(
|
||||
|
@ -286,31 +290,6 @@ class Weights:
|
|||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
else:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
|
@ -356,6 +335,10 @@ class Weights:
|
|||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
elif quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
|
@ -366,14 +349,31 @@ class Weights:
|
|||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
|
@ -425,10 +425,8 @@ class Weights:
|
|||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
|
@ -452,36 +450,6 @@ class Weights:
|
|||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
gptq_params = self._get_gptq_params()
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
|
||||
dim=1,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
|
@ -544,9 +512,41 @@ class Weights:
|
|||
)
|
||||
|
||||
elif quantize == "gptq":
|
||||
use_exllama = True
|
||||
gptq_params = self._get_gptq_params()
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if gptq_params.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
|
@ -672,10 +672,8 @@ class Weights:
|
|||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
quant_method = getattr(self, "quant_method", "marlin")
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
|
@ -698,35 +696,6 @@ class Weights:
|
|||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
elif quant_method == "gptq":
|
||||
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
weight = repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
|
@ -743,18 +712,17 @@ class Weights:
|
|||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_params(self) -> _GPTQParams:
|
||||
def _get_gptq_params(self) -> GPTQParams:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = False
|
||||
sym = True
|
||||
sym = False
|
||||
quant_method = "gptq"
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
|
@ -767,7 +735,7 @@ class Weights:
|
|||
except Exception:
|
||||
raise e
|
||||
|
||||
return _GPTQParams(
|
||||
return GPTQParams(
|
||||
bits=bits,
|
||||
checkpoint_format=checkpoint_format,
|
||||
desc_act=desc_act,
|
||||
|
|
Loading…
Reference in New Issue