Fix Starcoder2 after refactor (#2189)
This commit is contained in:
parent
853d4eb9cf
commit
b67d46336e
|
@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2Model(torch.nn.Module):
|
class Starcoder2Model(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.embed_tokens", weights=weights
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||||
prefix="model.norm", weights=weights, eps=config.norm_epsilon
|
prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = Starcoder2Model(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = Starcoder2Model(prefix, config, weights)
|
||||||
try:
|
try:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
|
@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="model.embed_tokens",
|
prefix=f"{prefix}.embed_tokens",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue