fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)

Fixes #555
This commit is contained in:
Nicolas Patry 2023-07-12 09:51:34 +02:00 committed by GitHub
parent b4024edd45
commit f0181436f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 1 deletions

View File

@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
config.quantize = quantize