from dataclasses import dataclass

import torch


@dataclass
class Exl2Weight:
    """
    Exllama2 exl2 quantized weights.

    A plain container for the tensors that make up one exl2-quantized
    weight. Two fields are normalized once, right after construction
    (see ``__post_init__``):

    * ``q_scale_max`` is replaced by ``q_scale_max / 256``.
    * ``q_invperm`` is replaced by its 16-bit integer (``.short()``) form.
    """

    # NOTE(review): the per-field meanings below are inferred from the names
    # only; confirm against the exl2 checkpoint format before relying on them.
    q_weight: torch.Tensor  # presumably the packed quantized weight data
    q_scale: torch.Tensor  # presumably the quantized scales
    q_invperm: torch.Tensor  # stored as int16 after __post_init__
    q_scale_max: torch.Tensor  # stored divided by 256 after __post_init__
    q_groups: torch.Tensor  # presumably the quantization group layout

    def __post_init__(self) -> None:
        """Normalize ``q_scale_max`` and ``q_invperm`` after construction."""
        # Rebind to a new tensor rather than using the in-place ``/=``: the
        # in-place form would also modify the tensor object owned by the
        # caller, so reusing that tensor for a second instance would apply
        # the 1/256 scaling twice. In-place true division additionally
        # raises for integer-dtype tensors, which this form accepts.
        self.q_scale_max = self.q_scale_max / 256
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        """Return the device on which ``q_weight`` is stored."""
        return self.q_weight.device