From 4b49c50f4c9de2a59dc746e0d1f960e358c70fd2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Fri, 26 Jul 2024 14:57:24 +0200
Subject: [PATCH] Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313)

---
 .../custom_modeling/flash_qwen2_modeling.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index a98709c5..f40d126b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -262,6 +262,9 @@ class Qwen2Layer(nn.Module):
 class Qwen2Model(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
+
+        prefix = f"{prefix}.model" if prefix else "model"
+
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -335,15 +338,16 @@ class Qwen2ForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
 
-        if not prefix:
-            prefix = "model"
-        else:
-            prefix = f"{prefix}.model"
-
         self.model = Qwen2Model(prefix, config, weights)
+
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
         self.max_past = config.sliding_window