feat: adjust to load weights

parent 8aece3bd68
commit cf8fdef9d3

@@ -136,6 +136,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    PHI3SMALL = {
+        "type": "phi3small",
+        "name": "Phi 3 Small",
+        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",

@@ -579,7 +584,12 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == PHI3SMALL
+    ):
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
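
For context, the dispatch above keys off the "type" string carried by each ModelType entry, so checkpoints whose config reports model_type "phi3small" are routed to FlashLlama alongside Llama, Baichuan, and Phi 3. A minimal sketch of that registration-and-lookup pattern (names simplified, not the actual get_model signature):

import enum


class ModelType(enum.Enum):
    # Each member carries the string that Hugging Face configs report as model_type.
    PHI3 = {"type": "phi3", "name": "Phi 3"}
    PHI3SMALL = {"type": "phi3small", "name": "Phi 3 Small"}
    GEMMA = {"type": "gemma", "name": "Gemma"}


def resolve(model_type: str) -> ModelType:
    # Map the raw config string back to the enum member, mirroring the lookup
    # that the dispatch in get_model relies on.
    for member in ModelType:
        if member.value["type"] == model_type:
            return member
    raise ValueError(f"Unsupported model_type: {model_type}")


print(resolve("phi3small"))  # ModelType.PHI3SMALL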

@@ -70,6 +70,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=bias,
         )
+    elif config.model_type == "phi3small":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=bias,
+        )
 
     # otherwise, load the default attention based on the number of heads
     return TensorParallelColumnLinear.load_multi(
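
The phi3small branch points at a single fused query_key_value projection rather than separate q_proj/k_proj/v_proj tensors. Purely as an illustration (not the library's loader, and assuming equal query/key/value head counts), splitting such a fused weight back into Q, K, and V slices looks roughly like this:

import torch

hidden_size = 64

# A fused QKV projection stores Q, K and V stacked along the output dimension.
qkv_weight = torch.randn(3 * hidden_size, hidden_size)

# Splitting it back out is a slice along dim 0; tensor-parallel loaders
# additionally shard each slice across ranks.
q_w, k_w, v_w = qkv_weight.split(hidden_size, dim=0)
print(q_w.shape, k_w.shape, v_w.shape)  # three tensors of shape (64, 64)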

@@ -93,12 +100,20 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        if config.model_type == "phi3small":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_embedding_base,
+                device=weights.device,
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
 
         self.softmax_scale = self.head_size**-0.5
 
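
Both branches construct the same rotary embedding; only the config attribute holding the base frequency differs (rope_embedding_base for phi3small, rope_theta otherwise). For reference, the inverse frequencies that such a base parameterizes are computed roughly as follows (a standard RoPE sketch, not the library's PositionRotaryEmbedding):

import torch


def rope_inv_freq(head_size: int, base: float) -> torch.Tensor:
    # Standard RoPE: one inverse frequency per pair of dimensions,
    # geometrically spaced by `base` (rope_theta / rope_embedding_base).
    return 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))


print(rope_inv_freq(head_size=8, base=10000.0))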

@@ -114,12 +129,21 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
+        if config.model_type == "phi3small":
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.dense",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
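
The module built is identical in both branches; only the checkpoint tensor name changes, since phi3small stores its attention output projection as "dense" rather than "o_proj". A hedged sketch of that prefix switch against a hypothetical state dict (names invented for illustration):

import torch

# Hypothetical state dict using phi3small-style naming for the attention
# output projection; Llama-style checkpoints use "...self_attn.o_proj.weight".
state_dict = {
    "model.layers.0.self_attn.dense.weight": torch.randn(64, 64),
}


def load_o_proj_weight(state_dict, prefix, model_type):
    # Pick the tensor name based on the checkpoint convention, mirroring the
    # prefix switch in FlashLlamaAttention above.
    name = "dense" if model_type == "phi3small" else "o_proj"
    return state_dict[f"{prefix}.{name}.weight"]


w = load_o_proj_weight(state_dict, "model.layers.0.self_attn", "phi3small")
print(w.shape)  # torch.Size([64, 64])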

@@ -209,6 +233,13 @@ class LlamaMLP(nn.Module):
                 weights=weights,
                 bias=bias,
             )
+        elif config.model_type == "phi3small":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.up_proj",
+                weights=weights,
+                bias=bias,
+            )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
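
For phi3small the gate and up projections ship as a single fused up_proj tensor, so load_gate_up is used instead of concatenating separate gate_proj/up_proj weights. Purely as an illustration (not the loader itself, and assuming the common layout that stacks gate rows before up rows), a fused gate-up weight splits like this:

import torch

hidden_size, intermediate_size = 64, 256

# Assumed fused layout: gate rows followed by up rows along the output dim.
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
gate_w, up_w = gate_up_weight.split(intermediate_size, dim=0)
print(gate_w.shape, up_w.shape)  # (256, 64) and (256, 64)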

@@ -259,13 +290,16 @@ class FlashLlamaLayer(nn.Module):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
+        if config.model_type == "phi3small":
+            eps = config.layer_norm_epsilon
+        else:
+            eps = config.rms_norm_eps
+
         self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
         )
         self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
         )
 
     def forward(
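
An alternative to branching on model_type here would be a defensive attribute lookup, since a given config exposes only one of the two epsilon names (layer_norm_epsilon for phi3small, rms_norm_eps for Llama-style models). A small hedged sketch of that alternative, not what the diff itself does:

from types import SimpleNamespace


def layer_norm_eps(config) -> float:
    # Prefer whichever epsilon attribute the config actually defines.
    eps = getattr(config, "layer_norm_epsilon", None)
    if eps is None:
        eps = getattr(config, "rms_norm_eps", 1e-5)
    return eps


print(layer_norm_eps(SimpleNamespace(layer_norm_epsilon=1e-5)))  # 1e-05
print(layer_norm_eps(SimpleNamespace(rms_norm_eps=1e-6)))        # 1e-06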

@@ -327,11 +361,19 @@ class FlashLlamaModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+
+        if config.model_type == "phi3small":
+            self.norm = FastRMSNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            self.norm = FastRMSNorm.load(
+                prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
 
         self.gradient_checkpointing = False
 
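
To confirm which attribute and tensor names a checkpoint actually exposes before wiring them into the loader, the config can be inspected directly. A quick check along these lines (the exact attribute set depends on the checkpoint's remote code, so treat missing fields as expected):

from transformers import AutoConfig

# Phi-3-small ships custom configuration code, hence trust_remote_code=True.
config = AutoConfig.from_pretrained(
    "microsoft/Phi-3-small-8k-instruct", trust_remote_code=True
)

# The fields the diff above relies on; print whichever are present.
for attr in ("model_type", "rope_embedding_base", "layer_norm_epsilon", "rms_norm_eps"):
    print(attr, getattr(config, attr, "<missing>"))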