Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313)
This commit is contained in:
parent 3905f854ed
commit 4b49c50f4c
@@ -262,6 +262,9 @@ class Qwen2Layer(nn.Module):
 class Qwen2Model(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
+
+        prefix = f"{prefix}.model" if prefix else "model"
+
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
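The added line moves prefix handling into Qwen2Model itself: with no outer prefix the weights resolve to "model.*", otherwise they nest under the caller's prefix. A minimal sketch of that resolution rule, assuming a hypothetical wrapper prefix "language_model" purely for illustration:

def resolve_model_prefix(prefix: str) -> str:
    # Same expression as the added line: fall back to "model" when no
    # outer prefix is given, otherwise nest under it.
    return f"{prefix}.model" if prefix else "model"

# Standalone Qwen2 checkpoint: weights live under "model.*".
assert resolve_model_prefix("") == "model"
# Hypothetical outer wrapper embedding the text model under "language_model":
# weights would then live under "language_model.model.*".
assert resolve_model_prefix("language_model") == "language_model.model"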
@@ -335,15 +338,16 @@ class Qwen2ForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
 
-        if not prefix:
-            prefix = "model"
-        else:
-            prefix = f"{prefix}.model"
-
         self.model = Qwen2Model(prefix, config, weights)
+
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
         self.max_past = config.sliding_window
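The new branch is what enables tied embeddings: for checkpoints with tie_word_embeddings set (as in the 0.5B and 1.5B Qwen2 models), the output head is loaded from the input embedding weights instead of a separate lm_head tensor. A small sketch of the prefix selection under that assumption; the helper name and example arguments are illustrative, not part of the diff:

def resolve_lm_head_prefix(prefix: str, tie_word_embeddings: bool) -> str:
    # Tied checkpoints reuse the input embedding matrix as the output
    # projection, so point the head loader at "model.embed_tokens";
    # untied checkpoints keep a dedicated "lm_head" tensor.
    suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
    return f"{prefix}.{suffix}" if prefix else suffix

# Tied case (e.g. Qwen2 0.5B / 1.5B), loaded without an outer prefix:
assert resolve_lm_head_prefix("", True) == "model.embed_tokens"
# Untied case with a dedicated output head:
assert resolve_lm_head_prefix("", False) == "lm_head"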