diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py
index 6a1cf117..7353afb5 100644
--- a/server/text_generation_server/utils/gptq/exllama.py
+++ b/server/text_generation_server/utils/gptq/exllama.py
@@ -69,10 +69,11 @@ def create_exllama_buffers():
     TEMP_STATE, TEMP_DQ = temp_state, temp_dq
 
 
-class Ex4bitLinear:
+class Ex4bitLinear(torch.nn.Module):
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+        super().__init__()
         global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
 
         assert bits == 4
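
For context, a minimal sketch of what this change buys (not part of the patch; `ToyQuantLinear` is a hypothetical stand-in for `Ex4bitLinear`): subclassing `torch.nn.Module` and calling `super().__init__()` initializes PyTorch's module bookkeeping, so the layer can register buffers/parameters and participate in `state_dict()` and `.to(device)`, which a plain Python class cannot.

```python
# Minimal sketch, not from the patch: ToyQuantLinear is a hypothetical
# stand-in illustrating why the torch.nn.Module base matters.
import torch


class ToyQuantLinear(torch.nn.Module):
    def __init__(self, qweight, scales):
        # Without this call, register_buffer below raises an AttributeError,
        # because the Module bookkeeping dicts are never initialized.
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("scales", scales)


layer = ToyQuantLinear(torch.zeros(4, 4, dtype=torch.int32), torch.ones(4))
print(list(layer.state_dict().keys()))  # ['qweight', 'scales']
```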