From 30620a9a44c3e1357d4f9e9a84b00fa1ede011cc Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 10 Apr 2024 18:38:08 +0200
Subject: [PATCH] hotfix: mixtral

---
 .../custom_modeling/flash_mixtral_modeling.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 89eb8f43..be8cb965 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -464,9 +464,9 @@ class DenseMoE(nn.Module):
 
 
 class MixtralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
 
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
@@ -525,16 +525,20 @@ class MixtralLayer(nn.Module):
 
 
 class MixtralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
         )
 
         self.layers = nn.ModuleList(
             [
                 MixtralLayer(
+                    "model" if not prefix else f"{prefix}.model",
                     layer_id,
                     config,
                     weights,
@@ -543,7 +547,9 @@ class MixtralModel(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )
 
         self.head_size = self.layers[0].self_attn.head_size
@@ -593,13 +599,13 @@ class MixtralModel(torch.nn.Module):
 
 
 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
-        self.model = MixtralModel(config, weights)
+        self.model = MixtralModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
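
A minimal usage sketch (not part of the patch) of what the new prefix argument changes when weights are resolved. The config and weights objects are assumed to come from the server's usual model-loading path, and the "language_model" prefix is a hypothetical example of nesting, not something this patch itself introduces:

    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )

    # Empty prefix keeps the historical weight names, e.g.
    # "model.embed_tokens.weight", "model.layers.0.self_attn...", "lm_head.weight".
    model = FlashMixtralForCausalLM(prefix="", config=config, weights=weights)

    # A non-empty (hypothetical) prefix rewrites every lookup, e.g.
    # "language_model.model.embed_tokens.weight", "language_model.lm_head.weight".
    nested = FlashMixtralForCausalLM(prefix="language_model", config=config, weights=weights)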