Fix Starcoder2 after refactor (#2189)

This commit is contained in:
Daniël de Kok 2024-07-05 12:22:45 +02:00 committed by GitHub
parent 853d4eb9cf
commit b67d46336e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 11 additions and 6 deletions

View File

@@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
class Starcoder2Model(torch.nn.Module): class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
@@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
] ]
) )
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
class FlashStarcoder2ForCausalLM(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.model = Starcoder2Model(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Starcoder2Model(prefix, config, weights)
try: try:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
except RuntimeError: except RuntimeError:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix=f"{prefix}.embed_tokens",
weights=weights, weights=weights,
) )