Hotfixing after refactor.

Rename module attributes to match the refactored naming (`Block.attn` → `self_attn`, `FlashSantacoderModel.h` → `layers`, `FlashSantacoderForCausalLM.transformer` → `model`) while keeping checkpoint prefixes unchanged, and defer building `Model.target_to_layer` until adapter load time.

Nicolas Patry 2024-07-05 09:25:29 +00:00
parent fb2f74e2b9
commit 853d4eb9cf
2 changed files with 11 additions and 9 deletions

flash_santacoder_modeling.py

@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )
 
-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
 
     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
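
A note on the renames above: only the Python attribute names change (attn → self_attn, h → layers, transformer → model); the checkpoint prefixes passed to the loaders (f"{prefix}.attn", "transformer.ln_f", "transformer.wte") are untouched, so the same tensors are loaded as before. A minimal sketch of that distinction, using hypothetical stand-ins (TinyAttention, TinyBlock) rather than the repo's classes:

import torch
from torch import nn


class TinyAttention(nn.Module):
    # Stand-in for FlashMQAttention: weights are located by an explicit
    # checkpoint prefix, not by the attribute name the module is bound to.
    def __init__(self, prefix: str, weights: dict):
        super().__init__()
        self.c_attn = nn.Parameter(weights[f"{prefix}.c_attn.weight"])


class TinyBlock(nn.Module):
    def __init__(self, prefix: str, weights: dict):
        super().__init__()
        # The attribute is renamed to `self_attn`, but the prefix stays
        # `.attn`, so the rename does not affect which tensors are read.
        self.self_attn = TinyAttention(prefix=f"{prefix}.attn", weights=weights)


weights = {"transformer.h.0.attn.c_attn.weight": torch.zeros(4, 4)}
block = TinyBlock(prefix="transformer.h.0", weights=weights)
assert block.self_attn.c_attn.shape == (4, 4)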

model.py

@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
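
The model.py change is a lazy-initialization fix: target_to_layer was built eagerly in Model.__init__ and is now built on the first adapter load instead (presumably so it runs after subclass construction is complete). A minimal sketch of the pattern, with hypothetical names (LazyModel, Concrete) around the two method names taken from the hunk:

from abc import ABC, abstractmethod
from typing import Dict, Optional


class LazyModel(ABC):
    def __init__(self):
        # Deferred: adapter_target_to_layer() may depend on attributes
        # that subclasses only set after this base __init__ returns.
        self.target_to_layer: Optional[Dict[str, object]] = None
        self.loaded_adapters = set()

    @abstractmethod
    def adapter_target_to_layer(self) -> Dict[str, object]: ...

    def load_adapter(self, adapter_index: int):
        # Build the mapping lazily, on the first adapter load.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # Adapter already loaded
        self.loaded_adapters.add(adapter_index)


class Concrete(LazyModel):
    def adapter_target_to_layer(self) -> Dict[str, object]:
        return {"q_proj": object()}


m = Concrete()
m.load_adapter(0)  # target_to_layer is materialized here, not in __init__
assert "q_proj" in m.target_to_layer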