fix(server): fix quantization for sharded models (#45)

This commit is contained in:
OlivierDehaene 2023-01-31 17:40:38 +01:00 committed by GitHub
parent 017a2a8c2f
commit c6e8b9442b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 21 deletions

View File

@@ -196,15 +196,11 @@ class BLOOMSharded(BLOOM):
tensor.CB = None
tensor.SCB = None
def replace_linear(state, in_features, out_features):
def replace_linear(state):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
out = input.new_empty(size_out)
out = bnb.matmul(
input,
weight,
out=out.view(-1, out_features),
state=state,
threshold=state.threshold,
bias=bias,
@@ -217,13 +213,11 @@ class BLOOMSharded(BLOOM):
del state.CB
weight.data = state.CxB
return out.view(size_out)
return out
return linear
module.linear = replace_linear(
state, module.in_features, module.out_features
)
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)

View File

@@ -232,7 +232,6 @@ class GalacticaSharded(Galactica):
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
tensor = tensor.transpose(1, 0)
else:
size = slice_.get_shape()[0]
block_size = size // world_size
@@ -246,7 +245,6 @@ class GalacticaSharded(Galactica):
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
tensor = tensor.transpose(1, 0)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
@@ -282,7 +280,7 @@ class GalacticaSharded(Galactica):
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
@@ -296,15 +294,11 @@ class GalacticaSharded(Galactica):
tensor.CB = None
tensor.SCB = None
def replace_linear(state, in_features, out_features):
def replace_linear(state):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
out = input.new_empty(size_out)
out = bnb.matmul(
input,
weight,
out=out.view(-1, out_features),
state=state,
threshold=state.threshold,
bias=bias,
@@ -317,13 +311,11 @@ class GalacticaSharded(Galactica):
del state.CB
weight.data = state.CxB
return out.view(size_out)
return out
return linear
module.linear = replace_linear(
state, module.in_features, module.out_features
)
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)