fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)
Fixes #555
This commit is contained in:
parent b4024edd45
commit f0181436f4
@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
 
         config.quantize = quantize
 
 
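The added aliases argument is the substance of the fix: RW (Falcon) checkpoints load through remote code, so the loader cannot inspect the architecture to learn that the LM head is tied to the word embeddings, and lm_head.weight is simply absent from the safetensors files. Mapping it onto transformer.word_embeddings.weight lets a lookup for either name resolve to the one tensor that is stored. Below is a minimal sketch of that resolution logic under the assumption of a dict-like tensor store; the AliasedWeights class and its internals are illustrative, not the actual Weights implementation in text_generation_server.

from typing import Dict, List


class AliasedWeights:
    """Illustrative alias-aware tensor lookup (not the real Weights class)."""

    def __init__(self, tensors: Dict[str, object], aliases: Dict[str, List[str]]):
        self.tensors = tensors
        # Invert the mapping: every alias points back to the canonical
        # name that actually exists in the checkpoint.
        self.alias_to_canonical = {
            alias: canonical
            for canonical, alias_list in aliases.items()
            for alias in alias_list
        }

    def get_tensor(self, name: str) -> object:
        # Tied weights are stored once, so when the requested name is
        # missing, fall back to its canonical name before indexing.
        if name not in self.tensors:
            name = self.alias_to_canonical.get(name, name)
        return self.tensors[name]


# RW/Falcon ties lm_head.weight to the embedding matrix, so only the
# embedding tensor is present on disk.
weights = AliasedWeights(
    tensors={"transformer.word_embeddings.weight": "embedding-tensor"},
    aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
assert weights.get_tensor("lm_head.weight") == "embedding-tensor"

Declaring the alias at construction time keeps call sites unchanged: model code can still request lm_head.weight without knowing the checkpoint stores it under the embedding name.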