From a1c23f3823d46dce6e9ac4b3f77cd142650cff2d Mon Sep 17 00:00:00 2001
From: Florian Zimmermeister
Date: Tue, 11 Jul 2023 18:47:50 +0200
Subject: [PATCH] Update layers.py

---
 server/text_generation_server/utils/layers.py | 65 +------------------
 1 file changed, 2 insertions(+), 63 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8e0362b8..eebdc097 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -8,7 +8,7 @@ from typing import List
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
-    from bitsandbytes.nn import Int8Params
+    from bitsandbytes.nn import LinearNF4
 except ImportError:
     HAS_BITS_AND_BYTES = False
 
@@ -71,74 +71,13 @@ class FastLinear(nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)
 
-
-class Linear8bitLt(nn.Module):
-    def __init__(
-        self,
-        weight,
-        bias,
-        has_fp16_weights=True,
-        memory_efficient_backward=False,
-        threshold=0.0,
-        index=None,
-    ):
-        super().__init__()
-        assert (
-            not memory_efficient_backward
-        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
-        self.state = bnb.MatmulLtState()
-        self.index = index
-
-        # Necessary for stacked layers
-        self.state.threshold = threshold
-        self.state.has_fp16_weights = has_fp16_weights
-        self.state.memory_efficient_backward = memory_efficient_backward
-        if threshold > 0.0 and not has_fp16_weights:
-            self.state.use_pool = True
-
-        self.weight = Int8Params(
-            weight.data,
-            has_fp16_weights=has_fp16_weights,
-            requires_grad=has_fp16_weights,
-        )
-        self.weight.cuda(weight.device)
-        self.bias = bias
-
-    def init_8bit_state(self):
-        self.state.CB = self.weight.CB
-        self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
-
-    def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-        if self.weight.CB is not None:
-            self.init_8bit_state()
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
-
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
-        return out
-
-
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
-        linear = Linear8bitLt(
+        linear = LinearNF4(
             weight,
             bias,
-            has_fp16_weights=False,
-            threshold=6.0,
         )
         if bias is not None:
             linear.bias = nn.Parameter(bias)
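
Note on the new call site: in bitsandbytes >= 0.39.0, LinearNF4 is a
Linear4bit subclass whose constructor takes feature dimensions rather
than tensors, so calling it as LinearNF4(weight, bias) as in the hunk
above may need a small adapter. Below is a minimal sketch of such an
adapter, assuming the bitsandbytes 0.39/0.40 API; the helper name
linear_nf4_from_weight is illustrative, not part of this repository,
and the Params4bit keywords are taken from that release line.

import torch
import torch.nn as nn
from bitsandbytes.nn import LinearNF4, Params4bit

def linear_nf4_from_weight(weight: torch.Tensor, bias):
    # LinearNF4 wants (in_features, out_features); recover them from
    # the fp16 weight, which is stored as (out_features, in_features).
    out_features, in_features = weight.shape
    linear = LinearNF4(in_features, out_features, bias=bias is not None)
    # Wrap the pre-trained weight; the actual NF4 quantization runs
    # when the parameter is moved onto a CUDA device.
    linear.weight = Params4bit(
        weight.data, requires_grad=False, quant_type="nf4"
    )
    if bias is not None:
        linear.bias = nn.Parameter(bias)
    # Mirrors the old Int8Params pattern: quantize on the weight's device.
    linear.weight.cuda(weight.device)
    return linear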