fix(server): Fix RW weight loading (it's remote code, so the architecture check can't determine which weights to keep). (#579)
Fixes #555
parent b4024edd45
commit f0181436f4
@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
 
         config.quantize = quantize
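For context: RW checkpoints tie the language-model head to the input embeddings, so the safetensors files store only transformer.word_embeddings.weight. Because the model is remote code, the usual architecture-based check can't detect this tying, so the fix passes an explicit aliases table telling the loader to serve a request for lm_head.weight from the stored embedding tensor. Below is a minimal sketch of how such alias resolution could work; the AliasedWeights class and its methods are illustrative assumptions, not the actual text-generation-inference Weights API.

from typing import Dict, List, Optional

import torch


class AliasedWeights:
    """Illustrative sketch: resolve tensor names through an alias table.

    When a checkpoint ties weights (as RW ties lm_head to the word
    embeddings), only one copy exists on disk; the alias table maps the
    stored name to the other names the model may request.
    """

    def __init__(
        self,
        tensors: Dict[str, torch.Tensor],
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        self._tensors = tensors
        # Invert the table: requested alias -> name actually stored on disk.
        self._aliases = {
            alias: stored
            for stored, alts in (aliases or {}).items()
            for alias in alts
        }

    def get_tensor(self, name: str) -> torch.Tensor:
        # Fall back to the alias table when the exact name is missing.
        if name not in self._tensors:
            name = self._aliases.get(name, name)
        return self._tensors[name]


# Hypothetical usage mirroring the fix above (dimensions are made up):
weights = AliasedWeights(
    {"transformer.word_embeddings.weight": torch.zeros(4, 8)},
    aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
assert weights.get_tensor("lm_head.weight").shape == (4, 8)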