Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditional in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overrided by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
  `get_linear` does not need to know how to handle quantizer linear
  layers.
- All quantizer weights are strongly typed, we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.

* Exclude non-MLP layers when using FP8 quantization with Llama
This commit is contained in:
Daniël de Kok 2024-07-19 09:37:39 +02:00 committed by GitHub
parent 1d1b1efa01
commit ba291dad9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 936 additions and 265 deletions

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -0.6645508,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -2.2324219,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 6088,
"logprob": -1.6074219,
"special": false,
"text": " parse"
},
{
"id": 1243,
"logprob": -1.6298828,
"special": false,
"text": " test"
},
{
"id": 1206,
"logprob": -0.72558594,
"special": false,
"text": " case"
},
{
"id": 1024,
"logprob": -0.40429688,
"special": false,
"text": " name"
},
{
"id": 515,
"logprob": 0.0,
"special": false,
"text": " from"
},
{
"id": 525,
"logprob": -1.2519531,
"special": false,
"text": " '"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not parse test case name from '"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
]

View File

@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]

View File

@ -0,0 +1,66 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_marlin24_handle(launcher):
with launcher(
"nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin24_handle):
await flash_llama_marlin24_handle.health(300)
return flash_llama_marlin24_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_load(
flash_llama_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -2,6 +2,7 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
@ -363,7 +364,10 @@ class MockWeights(Weights):
self.process_group = process_group
self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
# We don't need to get linear layers, so just wrap raw tensors.
DefaultWeightsLoader(lambda x: x)
if weights_loader is None
else weights_loader
)
self._handles = {}
@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"

View File

@ -1,8 +1,11 @@
import torch
from loguru import logger
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
from loguru import logger
from text_generation_server.utils.weights import Weight
@lru_cache(1)
@ -12,6 +15,14 @@ def warn_deprecate_bnb():
)
@dataclass
class BNBWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
class Linear8bitLt(torch.nn.Module):
def __init__(
self,
@ -70,6 +81,22 @@ class Linear8bitLt(torch.nn.Module):
return out
@dataclass
class BNBFP4Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="fp4")
@dataclass
class BNBNF4Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="nf4")
class Linear4bit(torch.nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()

View File

@ -1,5 +1,23 @@
from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import Weight
@dataclass
class EETQWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module):

View File

@ -1,12 +1,12 @@
import torch
from typing import List, Union
from dataclasses import dataclass
from typing import List, Union
from text_generation_server.utils.weights import WeightsLoader, Weights
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class Exl2Weight:
class Exl2Weight(Weight):
"""
Exllama2 exl2 quantized weights.
"""
@ -25,6 +25,11 @@ class Exl2Weight:
def device(self) -> torch.device:
return self.q_weight.device
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.gptq import ExllamaQuantLinear
return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""

View File

@ -1,7 +1,8 @@
from enum import Enum, auto
from dataclasses import dataclass
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weight
def get_fp8_linear() -> torch.nn.Module:
@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
return qweight, scale
@dataclass
class Fp8Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return get_fp8_linear()(self.weight, bias)
class Fp8Linear(torch.nn.Module):
def __init__(
self,

View File

@ -1,24 +1,23 @@
from dataclasses import dataclass
from loguru import logger
import os
from dataclasses import dataclass
from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class GPTQWeight:
class GPTQWeight(Weight):
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_awq_kernel: bool
use_exllama: bool
def __post_init__(self):
@ -29,6 +28,50 @@ class GPTQWeight:
def device(self) -> torch.device:
return self.qweight.device
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
group_size=self.groupsize,
qweight=self.qweight,
qzeros=self.qzeros,
scales=self.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
bias,
self.bits,
self.groupsize,
)
try:
major, _minor = torch.cuda.get_device_capability()
@ -45,6 +88,8 @@ elif CAN_EXLLAMA:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers,
set_device,
)
@ -53,6 +98,8 @@ elif CAN_EXLLAMA:
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers,
set_device,
)
@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=False,
)
@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)

View File

@ -1,7 +1,8 @@
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
if SYSTEM == "rocm":
try:
@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module):
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias, quantize):
if quantize is None:
def get_linear(weight, bias):
# Weights that are loaded through methods that are not
# quantization-aware are still bare tensors. We may want
# to change this in the future.
if isinstance(weight, torch.Tensor):
if SYSTEM == "rocm":
linear = FastLinearROCm(weight, bias)
return FastLinearROCm(weight, bias)
else:
linear = FastLinear(weight, bias)
elif quantize == "eetq":
try:
from text_generation_server.layers.eetq import EETQLinear
return FastLinear(weight, bias)
linear = EETQLinear(weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "fp8":
from text_generation_server.layers.fp8 import get_fp8_linear
linear = get_fp8_linear()(weight, bias)
elif quantize == "bitsandbytes":
try:
from text_generation_server.layers.bnb import (
warn_deprecate_bnb,
Linear8bitLt,
)
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
warn_deprecate_bnb()
linear = Linear8bitLt(
weight,
bias,
has_fp16_weights=False,
threshold=6.0,
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlinLinear,
GPTQMarlinWeight,
)
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQWeight):
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
)
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
weight.bits,
weight.groupsize,
)
else:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear(
w_bit=weight.bits,
group_size=weight.groupsize,
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
MarlinLinear,
MarlinWeight,
)
if isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:
raise NotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
return weight.get_linear(bias)

View File

@ -7,7 +7,7 @@ from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight
def get_weights_row(self, weights: Weights, prefix: str):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor):
@dataclass
class GPTQMarlinWeight:
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
@ -219,6 +217,12 @@ class GPTQMarlinWeight:
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
@ -376,6 +380,12 @@ class GPTQMarlin24Weight:
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
@dataclass
class MarlinWeight:
class MarlinWeight(Weight):
"""
Marlin weights.
@ -583,6 +593,9 @@ class MarlinWeight:
assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):

View File

@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer):
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize),
get_linear(weight, bias=None),
process_group=weights.process_group,
should_gather=should_gather,
)
@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer):
for prefix in prefixes:
weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linears.append(get_linear(weight, b))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(prefixes, dim=dim)
@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer):
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize),
get_linear(weight, bias),
process_group=weights.process_group,
)

View File

@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights):
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class FlashCohereAttention(torch.nn.Module):

View File

@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls):
if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None, config.quantize)
linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group))
else:
linear = get_linear(expert_slice, None, config.quantize)
linear = get_linear(expert_slice, None)
experts.append(cls(linear))
return experts

View File

@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma2Attention(torch.nn.Module):

View File

@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemmaAttention(torch.nn.Module):

View File

@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights):
bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool):
@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
get_linear(weight, bias), process_group=weights.process_group
)
@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool):
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
class FlashGPT2Attention(torch.nn.Module):

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
@ -25,7 +26,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -42,10 +42,16 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm":
try:
@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
)
@contextmanager
def no_fp8(weights: Weights):
weights_loader = weights.weights_loader
if (
isinstance(weights_loader, DefaultWeightsLoader)
and weights_loader.weight_class is Fp8Weight
):
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
with weights.use_loader(weights_loader):
yield
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
with no_fp8(weights):
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens"
if not prefix
else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
def forward(
self,

View File

@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
def _load_experts(config, prefix: str, mat, weights):

View File

@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.use_parallel_residual:
return linear
else:
@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.use_parallel_residual:
return linear
else:

View File

@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights):
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except for Phi uses bias=True
return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=True))
class FlashPhiAttention(torch.nn.Module):

View File

@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.parallel_attn:
return linear
else:

View File

@ -105,6 +105,7 @@ def _load_multi_mqa_gptq(
g_idx=g_idx,
bits=loader.bits,
groupsize=loader.groupsize,
use_awq_kernel=loader.quantize == "awq",
use_exllama=HAS_EXLLAMA,
)
@ -121,7 +122,7 @@ def _load_multi_mqa_gptq(
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")
@ -193,7 +194,7 @@ def _load_multi_mqa(
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_col(config, prefix: str, weights, bias: bool):
@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool):
@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
get_linear(weight, bias), process_group=weights.process_group
)

View File

@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class Starcoder2Attention(torch.nn.Module):

View File

@ -34,7 +34,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -698,7 +698,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader()):
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics2VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
@ -707,16 +707,12 @@ class Idefics2ForConditionalGeneration(nn.Module):
weights=weights,
)
quantize = config.quantize
try:
config.quantize = None
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
finally:
config.quantize = quantize
config.quantize = None
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents

View File

@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias):
bias = bias.to(device=weights.device)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return TensorParallelColumnLinear(linear)

View File

@ -1,11 +1,14 @@
from typing import Optional
import os
import json
import os
from dataclasses import dataclass
from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader,
)
@dataclass
@ -104,10 +107,30 @@ def get_loader(
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "bitsandbytes":
from text_generation_server.layers.bnb import BNBWeight
return DefaultWeightsLoader(BNBWeight)
elif quantize == "bitsandbytes-fp4":
from text_generation_server.layers.bnb import BNBFP4Weight
return DefaultWeightsLoader(BNBFP4Weight)
elif quantize == "bitsandbytes-nf4":
from text_generation_server.layers.bnb import BNBNF4Weight
return DefaultWeightsLoader(BNBNF4Weight)
elif quantize == "eetq":
from text_generation_server.layers.eetq import EETQWeight
return DefaultWeightsLoader(EETQWeight)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Weight
return DefaultWeightsLoader(Fp8Weight)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
@ -115,5 +138,7 @@ def get_loader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
elif quantize is None:
return DefaultWeightsLoader(UnquantizedWeight)
else:
return DefaultWeightsLoader()
raise ValueError(f"Unknown quantization method: {quantize}")

View File

@ -1,9 +1,13 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Dict, List, Optional, Union
from safetensors import safe_open
import torch
from safetensors import safe_open
from text_generation_server.utils.import_utils import SYSTEM
class WeightsLoader(ABC):
@ -62,7 +66,39 @@ class WeightsLoader(ABC):
...
class Weight(ABC):
"""Instances of this type implement unquantized/quantized/to-be
quantized weights."""
@abstractmethod
def get_linear(self, bias: torch.Tensor):
"""Create a linear layer from this weight."""
...
@dataclass
class UnquantizedWeight:
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
if SYSTEM == "rocm":
return FastLinearROCm(self.weight, bias)
else:
return FastLinear(self.weight, bias)
class DefaultWeightsLoader(WeightsLoader):
"""Weight loader that loads (unquantized) Torch tensors."""
def __init__(self, weight_class):
"""Create a loader. Weights will be wrapped using the given `weights_class`,
normally this will be `UnquantizedWeight`, but a quantizer-specific class
such as `Fp8Weight` can be used to quantize the weights during loading.
"""
self.weight_class = weight_class
"""
Loader that uses tensors as-is with the exception of applying sharding
and/or concatenation.
@ -74,16 +110,21 @@ class DefaultWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
return weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
return self.weight_class(
weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
),
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return torch.cat(w, dim=dim)
return self.weight_class(torch.cat(w, dim=dim))
def get_weights_row(self, weights: "Weights", prefix: str):
return weights.get_sharded(f"{prefix}.weight", dim=1)
return self.weight_class(
weights.get_sharded(f"{prefix}.weight", dim=1),
)
class Weights: