From c6e8b9442b1fcf7bbbe4be58fcd85047f69e4112 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 31 Jan 2023 17:40:38 +0100
Subject: [PATCH] fix(server): fix quantization for sharded models (#45)

---
 server/text_generation/models/bloom.py     | 12 +++---------
 server/text_generation/models/galactica.py | 16 ++++------------
 2 files changed, 7 insertions(+), 21 deletions(-)

diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 1b7635c5..7708bb4a 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -196,15 +196,11 @@ class BLOOMSharded(BLOOM):
                             tensor.CB = None
                             tensor.SCB = None
 
-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -217,13 +213,11 @@ class BLOOMSharded(BLOOM):
                                         del state.CB
                                         weight.data = state.CxB
 
-                                    return out.view(size_out)
+                                    return out
 
                                 return linear
 
-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)
 
                         else:
                             tensor = tensor.to(device)
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 5cc55865..4722e1d8 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -232,7 +232,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             size = slice_.get_shape()[0]
                             block_size = size // world_size
@@ -246,7 +245,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[:, start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             tensor = slice_[:]
                             # XXX: Hack for Rowlinear to add the bias only once.
@@ -282,7 +280,7 @@ class GalacticaSharded(Galactica):
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
-                                tensor.transpose(1, 0),
+                                tensor,
                                 has_fp16_weights=False,
                                 requires_grad=False,
                             ).to(device)
@@ -296,15 +294,11 @@ class GalacticaSharded(Galactica):
                             tensor.CB = None
                             tensor.SCB = None
 
-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -317,13 +311,11 @@ class GalacticaSharded(Galactica):
                                         del state.CB
                                         weight.data = state.CxB
 
-                                    return out.view(size_out)
+                                    return out
 
                                 return linear
 
-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)
 
                         else:
                             tensor = tensor.to(device)
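
Below is a minimal sketch (not the repository's exact code) of the int8 linear wrapper this patch converges on. bnb.matmul accepts inputs with arbitrary leading dimensions and returns an output of matching shape, so the wrapper no longer needs in_features/out_features or the manual view/new_empty bookkeeping, and the same closure works however a sharded weight was sliced; likewise, the Galactica changes stop transposing the slice, so Int8Params quantizes the weight in the same layout the non-quantized path uses. The MatmulLtState handling around the matmul call is assumed to mirror the surrounding load_weights code rather than being copied from it.

    # Hedged sketch: assumes bitsandbytes is importable as bnb and that `weight`
    # was wrapped as bnb.nn.Int8Params(..., has_fp16_weights=False) during loading.
    import bitsandbytes as bnb
    import torch


    def replace_linear(state: "bnb.MatmulLtState"):
        # `state` carries the quantized weight (CB) and its scales (SCB),
        # set up by load_weights before this closure is installed on the module.
        def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
            # bnb.matmul takes (..., in_features) inputs directly and returns
            # (..., out_features); no pre-allocated output buffer is needed.
            out = bnb.matmul(
                input,
                weight,
                state=state,
                threshold=state.threshold,
                bias=bias,
            )

            if not state.has_fp16_weights:
                if state.CB is not None and state.CxB is not None:
                    # The first pass converted the row-major int8 weight (CB) into
                    # the GPU-specific layout (CxB); keep only the converted copy.
                    del state.CB
                    weight.data = state.CxB

            return out

        return linear

Installing it as module.linear = replace_linear(state), as the patch does, routes every subsequent forward call of that layer through the int8 kernel.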