From 6c715f8183e27aad41c1928f978647d54c2ba395 Mon Sep 17 00:00:00 2001 From: Dhruv Srikanth <51223342+DhruvSrikanth@users.noreply.github.com> Date: Wed, 15 May 2024 20:08:32 +0100 Subject: [PATCH] [Bug Fix] Update torch import reference in bnb quantization (#1902) # What does this PR do? Fixes an `ImportError` caused by the inconsistent use of `torch.nn.Module` and `nn.Module`. --- server/text_generation_server/layers/bnb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index d27a33a1..ca39919c 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -70,7 +70,7 @@ class Linear8bitLt(torch.nn.Module): return out -class Linear4bit(nn.Module): +class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() self.weight = Params4bit(