from dataclasses import dataclass
from typing import List, Union

import torch

from text_generation_server.utils.weights import Weight, Weights, WeightsLoader


@dataclass
class Exl2Weight(Weight):
    """
    Exllama2 exl2 quantized weights.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self):
        # exl2 checkpoints store q_scale_max pre-multiplied by 256; undo that here.
        self.q_scale_max /= 256
        # The exllamav2 kernels expect the inverse permutation as int16.
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        return self.q_weight.device

    def get_linear(self, bias: torch.Tensor):
        # Imported lazily so the GPTQ/exllama kernels are only loaded when needed.
        from text_generation_server.layers.gptq import ExllamaQuantLinear

        return ExllamaQuantLinear(self, bias)


class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights."""

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.
        """
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl2")

    def get_weights_col(self, weights: Weights, prefix: str):
        # Sharding is not yet supported, so we return the weights as-is.
        return self.get_weights(weights, prefix)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
        # Sharding is not yet supported, so we return the weights as-is.
        return self.get_weights(weights, prefix)
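
# Usage sketch (illustrative, not part of the module): a minimal example of
# how this loader might be wired up, assuming a `Weights` instance already
# constructed from exl2-quantized safetensors shards. The variable `weights`
# and the prefix below are hypothetical.
#
#     loader = Exl2WeightsLoader()
#     weight = loader.get_weights(weights, "model.layers.0.self_attn.q_proj")
#     linear = weight.get_linear(bias=None)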