feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)
This commit is contained in:
parent
13e7044ab7
commit
ce960be0a5
|
@@ -136,7 +136,6 @@ class BLOOMSharded(BLOOM):
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
stop = (rank + 1) * block_size
|
stop = (rank + 1) * block_size
|
||||||
tensor = slice_[start:stop]
|
tensor = slice_[start:stop]
|
||||||
tensor = tensor.transpose(1, 0)
|
|
||||||
else:
|
else:
|
||||||
size = slice_.get_shape()[0]
|
size = slice_.get_shape()[0]
|
||||||
block_size = size // world_size
|
block_size = size // world_size
|
||||||
|
@@ -150,7 +149,6 @@ class BLOOMSharded(BLOOM):
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
stop = (rank + 1) * block_size
|
stop = (rank + 1) * block_size
|
||||||
tensor = slice_[:, start:stop]
|
tensor = slice_[:, start:stop]
|
||||||
tensor = tensor.transpose(1, 0)
|
|
||||||
else:
|
else:
|
||||||
tensor = slice_[:]
|
tensor = slice_[:]
|
||||||
# XXX: Hack for Rowlinear to add the bias only once.
|
# XXX: Hack for Rowlinear to add the bias only once.
|
||||||
|
@@ -186,7 +184,7 @@ class BLOOMSharded(BLOOM):
|
||||||
and param_name == "weight"
|
and param_name == "weight"
|
||||||
):
|
):
|
||||||
tensor = Int8Params(
|
tensor = Int8Params(
|
||||||
tensor.transpose(1, 0),
|
tensor,
|
||||||
has_fp16_weights=False,
|
has_fp16_weights=False,
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
Loading…
Reference in New Issue