diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index a0273c37..2b346283 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
 
 
 class Starcoder2Model(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
@@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
             ]
         )
         self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
-            prefix="model.norm", weights=weights, eps=config.norm_epsilon
+            prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
         )
 
         self.gradient_checkpointing = False
@@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
 
 
 class FlashStarcoder2ForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
-        self.model = Starcoder2Model(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = Starcoder2Model(prefix, config, weights)
         try:
             self.lm_head = SpeculativeHead.load(
                 config,
@@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         except RuntimeError:
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix="model.embed_tokens",
+                prefix=f"{prefix}.embed_tokens",
                 weights=weights,
            )
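
For context, a minimal sketch (not part of the diff) of how the new prefix argument resolves to weight names under the rewritten constructors; "parent" below is a hypothetical caller-supplied prefix, not a name defined by the library:

# Illustrative sketch only: mirrors the prefix resolution added to
# FlashStarcoder2ForCausalLM.__init__; "parent" is a hypothetical
# caller-supplied prefix, not a name used by text-generation-inference.
def resolve_model_prefix(prefix: str) -> str:
    # An empty prefix keeps the old behaviour (weights named "model.*");
    # a non-empty prefix nests the model under it ("<prefix>.model.*").
    return "model" if not prefix else f"{prefix}.model"

assert resolve_model_prefix("") == "model"
assert resolve_model_prefix("parent") == "parent.model"
# Derived weight names, as built in Starcoder2Model.__init__:
#   f"{resolve_model_prefix(prefix)}.embed_tokens"
#   f"{resolve_model_prefix(prefix)}.norm"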