feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

OlivierDehaene 2023-01-26 15:33:45 +01:00 committed by GitHub
parent 13e7044ab7
commit ce960be0a5
1 changed file with 1 addition and 3 deletions

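The three hunks below all serve one point: torch.nn.Linear stores its weight as (out_features, in_features) and computes y = x @ W.T + b, so tensors sliced straight out of the checkpoint can be assigned without the load-time transpose(1, 0) that the previous weight layout required. A minimal sketch of that layout (illustrative only, not part of the diff):

```python
import torch

# Illustrative sketch (not part of the diff): torch.nn.Linear keeps its
# weight as (out_features, in_features) and computes y = x @ W.T + b, so a
# checkpoint tensor already in that layout can be assigned directly.
linear = torch.nn.Linear(in_features=4, out_features=8)
assert linear.weight.shape == (8, 4)

# torch.nn.GELU is the stock activation referenced in the commit title.
activation = torch.nn.GELU()
out = activation(linear(torch.randn(2, 4)))
assert out.shape == (2, 8)
```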

@@ -136,7 +136,6 @@ class BLOOMSharded(BLOOM):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                        tensor = tensor.transpose(1, 0)
                     else:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
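The first hunk is the column-parallel case: the weight is sharded along dim 0, so each rank keeps a contiguous block of output features. Because nn.Linear already expects (out_features, in_features), the row slice is usable as-is and the old transpose is dropped. A hedged sketch of the same sharding arithmetic, with hypothetical names:

```python
import torch

# Hypothetical helper mirroring the hunk above: column-parallel sharding
# splits dim 0 (output features); no transpose is needed because nn.Linear
# weights are already stored as (out_features, in_features).
def shard_out_features(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    block_size = weight.shape[0] // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return weight[start:stop]

full = torch.randn(8, 4)                      # (out_features, in_features)
shard = shard_out_features(full, rank=1, world_size=2)
assert shard.shape == (4, 4)                  # 8 // 2 output rows per rank
```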
@@ -150,7 +149,6 @@ class BLOOMSharded(BLOOM):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[:, start:stop]
-                        tensor = tensor.transpose(1, 0)
                     else:
                         tensor = slice_[:]
                 # XXX: Hack for Rowlinear to add the bias only once.
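The second hunk is the row-parallel case: the weight is sharded along dim 1 (input features), so each rank computes a partial output that must be summed across ranks, and the bias may only be added once; that is what the "Hack for Rowlinear" comment guards. A hedged sketch of such a forward pass (assumes an initialized torch.distributed process group; names are hypothetical):

```python
import torch
import torch.distributed as dist

# Hypothetical sketch of a row-parallel forward: each rank holds a
# (out_features, in_features // world_size) weight shard and a matching
# slice of the input along its last dimension.
def row_parallel_forward(x_shard, weight_shard, bias):
    partial = torch.nn.functional.linear(x_shard, weight_shard)  # no bias yet
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # sum partial outputs
    return partial + bias  # bias applied exactly once, after the reduction
```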
@@ -186,7 +184,7 @@ class BLOOMSharded(BLOOM):
                     and param_name == "weight"
                 ):
                     tensor = Int8Params(
-                        tensor.transpose(1, 0),
+                        tensor,
                         has_fp16_weights=False,
                         requires_grad=False,
                     ).to(device)
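The last hunk touches the int8 quantization path: the weight is handed to bitsandbytes' Int8Params, and since it now already sits in nn.Linear's (out_features, in_features) layout it is passed through unchanged. A hedged sketch of that wrapping (assumes bitsandbytes and a CUDA device are available):

```python
import torch
from bitsandbytes.nn import Int8Params  # assumes bitsandbytes is installed

# Sketch of the quantized path: the weight is wrapped as-is, with no
# transpose(1, 0); bitsandbytes quantizes it when moved to the GPU.
weight = torch.randn(8, 4, dtype=torch.float16)
tensor = Int8Params(
    weight,
    has_fp16_weights=False,
    requires_grad=False,
).to("cuda")
```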