Hotfixing after refactor.

Nicolas Patry 2024-07-05 09:25:29 +00:00
parent fb2f74e2b9
commit 853d4eb9cf
No known key found for this signature in database
GPG Key ID: B154A218C20EBBCA
2 changed files with 11 additions and 9 deletions

View File

@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )
-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads

     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,

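The first file is a pure rename pass: Block.attn becomes Block.self_attn, FlashSantacoderModel.h becomes FlashSantacoderModel.layers, and FlashSantacoderForCausalLM.transformer becomes FlashSantacoderForCausalLM.model, while the weight prefixes such as "transformer.wte" are left untouched. Presumably this lines the Santacoder modules up with the model.layers[i].self_attn attribute paths that name-based tooling, like the adapter target mapping touched in the second file, keys on. A minimal sketch of that idea, using made-up Toy* classes rather than the repository's code:

# Illustrative only: these Toy* classes are not from the repository; they mirror
# the attribute layout after the rename (model -> layers -> self_attn).
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_attn = nn.Linear(8, 8)


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()  # was `self.attn` before the hotfix


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(2)])  # was `self.h`


class ToyForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyModel()  # was `self.transformer`
        self.lm_head = nn.Linear(8, 32)


def attention_layers(model: nn.Module):
    # Collect every submodule whose dotted path ends in "self_attn",
    # the way name-based adapter mapping would look modules up.
    return {name: mod for name, mod in model.named_modules() if name.endswith("self_attn")}


print(sorted(attention_layers(ToyForCausalLM())))
# ['model.layers.0.self_attn', 'model.layers.1.self_attn']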
View File

@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
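The second file defers building target_to_layer: instead of calling adapter_target_to_layer() eagerly in __init__, the mapping is now built on the first load_adapter call, presumably so that models which never load an adapter do not pay for (or depend on) that traversal at construction time. A minimal sketch of the same lazy-initialization pattern; LazyAdapterModel and its method bodies are illustrative, not the repository's code:

# Illustrative sketch of the lazy-initialization pattern introduced above.
# Only the shape of the pattern mirrors the diff: build target_to_layer on
# first use inside load_adapter instead of eagerly in __init__.
from typing import Dict, Optional, Tuple


class LazyAdapterModel:
    def __init__(self):
        # Before the hotfix this was: self.target_to_layer = self.adapter_target_to_layer()
        self.target_to_layer: Optional[Dict[str, Tuple[str, object]]] = None
        self.loaded_adapters = set()

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, object]]:
        # Stand-in for the real traversal, which maps adapter target names
        # to (module path, module) pairs and can be relatively costly.
        return {"q_proj": ("model.layers.0.self_attn.q_proj", object())}

    def load_adapter(self, adapter_index: int) -> None:
        if self.target_to_layer is None:
            # Built exactly once, the first time any adapter is loaded.
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return
        self.loaded_adapters.add(adapter_index)


m = LazyAdapterModel()
assert m.target_to_layer is None       # nothing is built at construction time
m.load_adapter(0)
assert "q_proj" in m.target_to_layer   # built lazily on the first load_adapter call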