diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2bc305fe..daef43cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -355,7 +355,7 @@ class Block(nn.Module): self.ln_2 = FastLayerNorm.load( prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon ) - self.attn = FlashMQAttention( + self.self_attn = FlashMQAttention( prefix=f"{prefix}.attn", config=config, weights=weights, @@ -378,7 +378,7 @@ class Block(nn.Module): max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( + hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, @@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module): reduce=False, ) - self.h = nn.ModuleList( + self.layers = nn.ModuleList( [ Block( layer_id, @@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module): prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon ) - self.head_size = self.h[0].attn.head_size - self.num_heads = self.h[0].attn.num_heads + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads def forward( self, @@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module): torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None - for i, layer in enumerate(self.h): + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, residual, @@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.transpose = config.architectures[0].startswith("GPT2") - self.transformer = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) @@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer( + hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index c90fd38a..09130b85 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -60,7 +60,7 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = self.adapter_target_to_layer() + self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -187,6 +187,8 @@ class Model(ABC): into model. Otherwise, the adapter weights are applied during the forward pass and stored separately from the base model parameters. """ + if self.target_to_layer is None: + self.target_to_layer = self.adapter_target_to_layer() if adapter_index in self.loaded_adapters: # Adapter already loaded return