From f0181436f41c730aad068029f54a3f86f354442d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 09:51:34 +0200
Subject: [PATCH] fix(server): Fixing RW code (it's remote code so the Arch
 checking doesn't work to see which weights to keep). (#579)

Fixes #555
---
 server/text_generation_server/models/flash_rw.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 12b862d7..55d555fc 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
 
         config.quantize = quantize
 
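
Note: the `aliases` mapping added above lets the weight loader answer a request for
`lm_head.weight` with the tied `transformer.word_embeddings.weight` tensor, which is
needed because RW/Falcon checkpoints use remote code and the usual architecture-based
check of which weights to keep cannot run. Below is a minimal, hypothetical sketch of
that alias-fallback lookup; `AliasedWeights` and its plain-dict storage are illustrative
stand-ins, not the actual `Weights` implementation in text_generation_server. Only the
alias mapping itself comes from the patch.

from typing import Dict, List


class AliasedWeights:
    """Hypothetical loader: resolves tensor names through an alias table."""

    def __init__(self, tensors: Dict[str, object], aliases: Dict[str, List[str]]):
        self.tensors = tensors
        # Route every known name to itself, then point each alias at its
        # canonical name (without overriding names the checkpoint really has).
        self.routing = {name: name for name in tensors}
        for canonical, alias_list in aliases.items():
            for alias in alias_list:
                self.routing.setdefault(alias, canonical)

    def get_tensor(self, name: str):
        canonical = self.routing.get(name)
        if canonical is None:
            raise RuntimeError(f"weight {name} not found in checkpoint")
        return self.tensors[canonical]


# With the alias from the patch, asking for `lm_head.weight` resolves to the
# tied embedding weights even though the safetensors file never stores it.
weights = AliasedWeights(
    {"transformer.word_embeddings.weight": "embedding-tensor"},
    aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
assert weights.get_tensor("lm_head.weight") == "embedding-tensor"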