feat(server): Support BLOOMChat-176B (#348) (#351)

@njhill, 
temporary workaround to be able to run our CI as secrets are not
available to runners run by external contributors. I will ask around to
see if there is a better way.

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
OlivierDehaene 2023-05-22 13:36:00 +02:00 committed by GitHub
parent 1ba78207e6
commit e649bf9a55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 2 deletions

View File

@ -131,7 +131,10 @@ class BLOOMSharded(BLOOM):
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
if name.startswith("transformer.") or name.startswith("lm_head."):
full_name = name
else:
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
@ -157,7 +160,7 @@ class BLOOMSharded(BLOOM):
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size

View File

@ -504,6 +504,7 @@ class CausalLM(Model):
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
)
return outputs.logits, outputs.past_key_values