diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
index 26224118..8548e376 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
@@ -11,57 +11,57 @@
       },
       {
         "id": 3226,
-        "logprob": -8.9453125,
+        "logprob": -9.0234375,
         "text": " ge"
       },
       {
         "id": 21017,
-        "logprob": -8.8515625,
+        "logprob": -9.0859375,
         "text": "ometric"
       },
       {
         "id": 81,
-        "logprob": -0.21875,
+        "logprob": -0.25585938,
         "text": "_"
       },
       {
         "id": 6009,
-        "logprob": -1.2773438,
+        "logprob": -2.1972656,
         "text": "mean"
       },
       {
         "id": 26,
-        "logprob": -0.25195312,
+        "logprob": -0.2998047,
         "text": "("
       },
       {
         "id": 62,
-        "logprob": -4.8203125,
+        "logprob": -5.6445312,
         "text": "L"
       },
       {
         "id": 44,
-        "logprob": -3.7734375,
+        "logprob": -3.0839844,
         "text": ":"
       },
       {
         "id": 1682,
-        "logprob": -0.8310547,
+        "logprob": -0.6748047,
         "text": " List"
       },
       {
         "id": 77,
-        "logprob": -0.22766113,
+        "logprob": -0.3864746,
         "text": "["
       },
       {
         "id": 1808,
-        "logprob": -0.46240234,
+        "logprob": -0.9355469,
         "text": "float"
       },
       {
         "id": 10794,
-        "logprob": -3.0234375,
+        "logprob": -2.5371094,
         "text": "]):"
       }
     ],
@@ -69,7 +69,7 @@
     "tokens": [
       {
         "id": 284,
-        "logprob": -0.04626465,
+        "logprob": -1.1679688,
         "special": false,
         "text": "\n "
       },
diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
index 015912f8..a6b80534 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
@@ -11,57 +11,57 @@
       },
       {
         "id": 3226,
-        "logprob": -8.9453125,
+        "logprob": -9.015625,
         "text": " ge"
       },
       {
         "id": 21017,
-        "logprob": -8.859375,
+        "logprob": -9.0859375,
         "text": "ometric"
       },
       {
         "id": 81,
-        "logprob": -0.21984863,
+        "logprob": -0.25585938,
         "text": "_"
       },
       {
         "id": 6009,
-        "logprob": -1.2861328,
+        "logprob": -2.2304688,
         "text": "mean"
       },
       {
         "id": 26,
-        "logprob": -0.25219727,
+        "logprob": -0.29760742,
         "text": "("
       },
       {
         "id": 62,
-        "logprob": -4.8007812,
+        "logprob": -5.6796875,
         "text": "L"
       },
       {
         "id": 44,
-        "logprob": -3.7949219,
+        "logprob": -3.0742188,
         "text": ":"
       },
       {
         "id": 1682,
-        "logprob": -0.8046875,
+        "logprob": -0.67626953,
         "text": " List"
       },
       {
         "id": 77,
-        "logprob": -0.22424316,
+        "logprob": -0.38842773,
         "text": "["
       },
       {
         "id": 1808,
-        "logprob": -0.46191406,
+        "logprob": -0.9165039,
         "text": "float"
       },
       {
         "id": 10794,
-        "logprob": -3.0253906,
+        "logprob": -2.5527344,
         "text": "]):"
       }
     ],
@@ -69,7 +69,7 @@
     "tokens": [
      {
         "id": 284,
-        "logprob": 0.0,
+        "logprob": -0.048583984,
         "special": false,
         "text": "\n "
       },
diff --git a/server/text_generation_server/layers/awq/quantize/__init__.py b/server/text_generation_server/layers/awq/quantize/__init__.py
new file mode 100644
index 00000000..3e72881b
--- /dev/null
+++ b/server/text_generation_server/layers/awq/quantize/__init__.py
@@ -0,0 +1,8 @@
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    from .ipex import WQLinear
+elif SYSTEM == "cuda":
+    from .cuda import WQLinear
+
+__all__ = ["WQLinear"]
diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/cuda.py
similarity index 100%
rename from server/text_generation_server/layers/awq/quantize/qmodule.py
rename to server/text_generation_server/layers/awq/quantize/cuda.py
diff --git a/server/text_generation_server/layers/awq/quantize/ipex.py b/server/text_generation_server/layers/awq/quantize/ipex.py
new file mode 100644
index 00000000..84cd7a21
--- /dev/null
+++ b/server/text_generation_server/layers/awq/quantize/ipex.py
@@ -0,0 +1,48 @@
+from typing import Optional
+import torch
+import torch.nn as nn
+import intel_extension_for_pytorch as ipex
+
+
+class WQLinear(nn.Module):
+    def __init__(
+        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+    ):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit is supported for now.")
+
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert self.out_features % (32 // self.w_bit) == 0
+
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        self.bias = bias
+        self.woq_linear = (
+            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.in_features,
+                self.out_features,
+                bias=self.bias,
+                group_size=self.group_size,
+                quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
+                dtype=ipex.llm.quantization.QuantDtype.INT4,
+            )
+        )
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 1fd183fa..63131dee 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -8,6 +8,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
+if SYSTEM == "ipex":
+    from .ipex import QuantLinear
+elif SYSTEM == "cuda":
+    from .cuda import QuantLinear
+
 
 @dataclass
 class GPTQWeight(Weight):
@@ -36,7 +41,7 @@ class GPTQWeight(Weight):
                     "to use Exllama/GPTQ kernels for AWQ inference."
                 )
             try:
-                from text_generation_server.layers.awq.quantize.qmodule import WQLinear
+                from text_generation_server.layers.awq.quantize import WQLinear
 
                 return WQLinear(
                     w_bit=self.bits,
@@ -60,8 +65,6 @@ class GPTQWeight(Weight):
 
             return ExllamaQuantLinear(self, bias)
         else:
-            from text_generation_server.layers.gptq.quant_linear import QuantLinear
-
             return QuantLinear(
                 self.qweight,
                 self.qzeros,
@@ -298,6 +301,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self._get_gptq_params(weights)
 
         use_exllama = True
+        desc_act = self.desc_act
         if self.bits != 4:
             use_exllama = False
 
@@ -321,7 +325,8 @@ class GPTQWeightsLoader(WeightsLoader):
             if g_idx is not None:
                 if (
                     not torch.equal(
-                        g_idx.cpu(),
+                        # Remove g_idx[0] to adapt the check with TP>1.
+                        (g_idx - g_idx[0]).cpu(),
                         torch.tensor(
                             [i // self.groupsize for i in range(g_idx.shape[0])],
                             dtype=torch.int32,
@@ -332,6 +337,7 @@ class GPTQWeightsLoader(WeightsLoader):
                 # Exllama implementation does not support row tensor parallelism with act-order, as
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False
+                desc_act = True
 
         from text_generation_server.layers.gptq import (
             CAN_EXLLAMA,
@@ -350,16 +356,16 @@ class GPTQWeightsLoader(WeightsLoader):
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
 
-        if use_exllama and self.groupsize != -1:
+        if not desc_act and self.groupsize != -1:
             qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
             scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+            if g_idx is not None:
+                # qzeros, scales sharded, and g_idx must be adjusted accordingly
+                g_idx = g_idx - g_idx[0]
         else:
             qzeros = weights.get_tensor(f"{prefix}.qzeros")
             scales = weights.get_tensor(f"{prefix}.scales")
 
-        if use_exllama and g_idx is not None:
-            g_idx = g_idx - g_idx[0]
-
         if self.quantize == "gptq" and self.quant_method == "awq":
             log_once(
                 logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/cuda.py
similarity index 100%
rename from server/text_generation_server/layers/gptq/quant_linear.py
rename to server/text_generation_server/layers/gptq/cuda.py
diff --git a/server/text_generation_server/layers/gptq/ipex.py b/server/text_generation_server/layers/gptq/ipex.py
new file mode 100644
index 00000000..ab9c9e24
--- /dev/null
+++ b/server/text_generation_server/layers/gptq/ipex.py
@@ -0,0 +1,126 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+import intel_extension_for_pytorch as ipex
+
+
+class QuantLinear(nn.Module):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+        super().__init__()
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.register_buffer("bias", bias)
+        else:
+            self.bias = None
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+        self.bits = bits
+        self.maxq = 2**self.bits - 1
+        self.groupsize = groupsize
+
+        self.outfeatures = qweight.shape[1]
+        self.infeatures = qweight.shape[0] * 32 // bits
+        self.woq_linear = (
+            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.infeatures,
+                self.outfeatures,
+                bias=self.bias,
+                group_size=self.groupsize,
+                g_idx=g_idx,
+                quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
+                dtype=ipex.llm.quantization.QuantDtype.INT4,
+            )
+        )
+
+    @classmethod
+    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
+        qzeros = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
+            dtype=torch.int32,
+        )
+        scales = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
+        )
+        g_idx = torch.tensor(
+            [i // groupsize for i in range(infeatures)], dtype=torch.int32
+        )
+        if bias:
+            bias = torch.zeros((outfeatures), dtype=torch.float16)
+        else:
+            bias = None
+        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+
+    def pack(self, linear, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+                    / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [4]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            else:
+                raise NotImplementedError("Only 4 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros(
+            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
+        )
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [4]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            else:
+                raise NotImplementedError("Only 4 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index b0086ea0..66fc15ec 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -12,7 +12,7 @@ from huggingface_hub import HfApi
 from accelerate import init_empty_weights
 from text_generation_server.utils import initialize_torch_distributed, Weights
 from text_generation_server.utils.hub import weight_files
-from text_generation_server.layers.gptq.quant_linear import QuantLinear
+from text_generation_server.layers.gptq import QuantLinear
 from loguru import logger
 from typing import Optional
 from text_generation_server.layers.gptq.utils import torch_snr_error
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index de0c66e7..0860e9ee 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -400,8 +400,11 @@ def get_model(
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
-            # These quantizers only work with float16 params.
-            dtype = torch.float16
+            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
+                dtype = torch.bfloat16
+            else:
+                # These quantizers only work with float16 params.
+                dtype = torch.float16
         elif quantize == "fp8":
             from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7018edb1..b1270b44 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1122,7 +1122,6 @@ class FlashCausalLM(Model):
                 dtype = default_dtype if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
@@ -1602,8 +1601,6 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        print(slots)
-
        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
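
Note (editor's sketch, not part of the patch): after this change, `QuantLinear` and `WQLinear` are re-exported from their package `__init__` modules and resolve to the `ipex` or `cuda` backend based on `SYSTEM`. The snippet below illustrates the constructor signature used by `GPTQWeight.get_linear`; the shapes and all-zero dummy tensors are assumptions for illustration only and require the corresponding backend (IPEX or CUDA kernels) to actually run.

# Editor's illustration only: dummy 4-bit GPTQ tensors; real weights come from the Weights loader.
import torch
from text_generation_server.layers.gptq import QuantLinear  # dispatches to .ipex or .cuda via SYSTEM

bits, groupsize, in_features, out_features = 4, 128, 4096, 4096
qweight = torch.zeros((in_features * bits // 32, out_features), dtype=torch.int32)
qzeros = torch.zeros((in_features // groupsize, out_features * bits // 32), dtype=torch.int32)
scales = torch.ones((in_features // groupsize, out_features), dtype=torch.float16)
g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)

layer = QuantLinear(qweight, qzeros, scales, g_idx, None, bits, groupsize)
y = layer(torch.randn(2, in_features, dtype=torch.float16))  # -> shape (2, out_features)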