diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index d27a33a1..ca39919c 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -70,7 +70,7 @@ class Linear8bitLt(torch.nn.Module): return out -class Linear4bit(nn.Module): +class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() self.weight = Params4bit(