fix(server): fix quantization for sharded models (#45)
commit c6e8b9442b
parent 017a2a8c2f
@@ -196,15 +196,11 @@ class BLOOMSharded(BLOOM):
                     tensor.CB = None
                     tensor.SCB = None

-                    def replace_linear(state, in_features, out_features):
+                    def replace_linear(state):
                         def linear(input, weight, bias):
-                            size_out = input.size()[:-1] + (out_features,)
-                            input = input.view(-1, in_features)
-                            out = input.new_empty(size_out)
                             out = bnb.matmul(
                                 input,
                                 weight,
-                                out=out.view(-1, out_features),
                                 state=state,
                                 threshold=state.threshold,
                                 bias=bias,
@@ -217,13 +213,11 @@ class BLOOMSharded(BLOOM):
                                 del state.CB
                                 weight.data = state.CxB

-                            return out.view(size_out)
+                            return out

                         return linear

-                    module.linear = replace_linear(
-                        state, module.in_features, module.out_features
-                    )
+                    module.linear = replace_linear(state)

                 else:
                     tensor = tensor.to(device)
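For reference, a minimal sketch of the patched helper as it reads after these two hunks, with explanatory comments. It assumes `state` is the bnb.MatmulLtState prepared in the unchanged lines above and that the cleanup shown in the context sits under the usual `state.CB is not None` guard; the explicit size_out / new_empty bookkeeping can go because bnb.matmul allocates its own output and preserves the input's leading dimensions.

import bitsandbytes as bnb

# `state` is assumed to be a bnb.MatmulLtState with threshold, CB and SCB
# already populated by the surrounding load_weights code.
def replace_linear(state):
    def linear(input, weight, bias):
        # bnb.matmul returns a fresh tensor shaped like the input with its
        # last dimension replaced, so no manual view / new_empty is needed.
        out = bnb.matmul(
            input,
            weight,
            state=state,
            threshold=state.threshold,
            bias=bias,
        )

        if state.CB is not None:
            # After the first pass the row-major int8 copy is no longer
            # needed; keep only the GPU-layout weights (CxB).
            del state.CB
            weight.data = state.CxB

        return out

    return linear

# module.linear = replace_linear(state)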
@@ -232,7 +232,6 @@ class GalacticaSharded(Galactica):
                     start = rank * block_size
                     stop = (rank + 1) * block_size
                     tensor = slice_[start:stop]
-                    tensor = tensor.transpose(1, 0)
                 else:
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
@@ -246,7 +245,6 @@ class GalacticaSharded(Galactica):
                     start = rank * block_size
                     stop = (rank + 1) * block_size
                     tensor = slice_[:, start:stop]
-                    tensor = tensor.transpose(1, 0)
                 else:
                     tensor = slice_[:]
                     # XXX: Hack for Rowlinear to add the bias only once.
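The two hunks above drop the extra transpose after slicing, so each rank's shard keeps the layout it has in the checkpoint instead of being flipped before quantization. A rough sketch of the slicing pattern with hypothetical tensor names, a hypothetical file name, and hard-coded rank / world_size (the real code takes these from the process group):

from safetensors import safe_open

rank, world_size = 0, 2  # hypothetical values

with safe_open("model.safetensors", framework="pt") as f:  # hypothetical file
    # Column-parallel weight: shard along dim 0 (output features).
    slice_ = f.get_slice("model.decoder.layers.0.fc1.weight")  # hypothetical name
    size = slice_.get_shape()[0]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    col_shard = slice_[start:stop]  # [out_features // world_size, in_features]

    # Row-parallel weight: shard along dim 1 (input features).
    slice_ = f.get_slice("model.decoder.layers.0.fc2.weight")  # hypothetical name
    size = slice_.get_shape()[1]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    row_shard = slice_[:, start:stop]  # [out_features, in_features // world_size]

# With this commit, neither shard is transposed afterwards.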
@@ -282,7 +280,7 @@ class GalacticaSharded(Galactica):
                     and param_name == "weight"
                 ):
                     tensor = Int8Params(
-                        tensor.transpose(1, 0),
+                        tensor,
                         has_fp16_weights=False,
                         requires_grad=False,
                     ).to(device)
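In the same spirit, the hunk above hands the weight shard to Int8Params as-is instead of transposing it first. A minimal sketch of that quantization step with a hypothetical shard shape; with has_fp16_weights=False, bitsandbytes quantizes the parameter when it is moved to a CUDA device and exposes the int8 data and scales as CB and SCB, the attributes the later hunks reset to None once the MatmulLtState has been set up.

import torch
from bitsandbytes.nn import Int8Params

# Hypothetical shard in the stored [out_features, in_features] layout.
weight_shard = torch.randn(2048, 4096, dtype=torch.float16)

tensor = Int8Params(
    weight_shard,  # passed as-is, no transpose(1, 0)
    has_fp16_weights=False,
    requires_grad=False,
).to("cuda")  # the move to a CUDA device triggers the int8 quantization

# CB holds the quantized weight, SCB its scaling factors.
print(tensor.CB.shape, tensor.SCB.shape)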
@@ -296,15 +294,11 @@ class GalacticaSharded(Galactica):
                     tensor.CB = None
                     tensor.SCB = None

-                    def replace_linear(state, in_features, out_features):
+                    def replace_linear(state):
                         def linear(input, weight, bias):
-                            size_out = input.size()[:-1] + (out_features,)
-                            input = input.view(-1, in_features)
-                            out = input.new_empty(size_out)
                             out = bnb.matmul(
                                 input,
                                 weight,
-                                out=out.view(-1, out_features),
                                 state=state,
                                 threshold=state.threshold,
                                 bias=bias,
@@ -317,13 +311,11 @@ class GalacticaSharded(Galactica):
                                 del state.CB
                                 weight.data = state.CxB

-                            return out.view(size_out)
+                            return out

                         return linear

-                    module.linear = replace_linear(
-                        state, module.in_features, module.out_features
-                    )
+                    module.linear = replace_linear(state)

                 else:
                     tensor = tensor.to(device)