feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

This commit is contained in:
OlivierDehaene 2023-01-26 15:33:45 +01:00 committed by GitHub
parent 13e7044ab7
commit ce960be0a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 3 deletions

View File

@ -136,7 +136,6 @@ class BLOOMSharded(BLOOM):
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
tensor = tensor.transpose(1, 0)
else:
size = slice_.get_shape()[0]
block_size = size // world_size
@ -150,7 +149,6 @@ class BLOOMSharded(BLOOM):
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
tensor = tensor.transpose(1, 0)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
@ -186,7 +184,7 @@ class BLOOMSharded(BLOOM):
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)