feat: adjust to load weights
parent 8aece3bd68
commit cf8fdef9d3
@@ -136,6 +136,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    PHI3SMALL = {
+        "type": "phi3small",
+        "name": "Phi 3 Small",
+        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",
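This first hunk registers Phi 3 Small in the ModelType registry (the first two hunks sit in the model registry around ModelType and get_model; the rest are in the flash Llama modeling module). ModelType members are dicts matched against the serialized "type" string. A minimal sketch of that pattern; resolve() is illustrative, not the repo's function:

import enum

class ModelType(enum.Enum):
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    PHI3SMALL = {
        "type": "phi3small",
        "name": "Phi 3 Small",
        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
    }

def resolve(model_type: str) -> ModelType:
    # Scan members and match on the "type" key of each value dict.
    for member in ModelType:
        if member.value["type"] == model_type:
            return member
    raise ValueError(f"unknown model type: {model_type}")

assert resolve("phi3small") is ModelType.PHI3SMALL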
@@ -579,7 +584,12 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == PHI3SMALL
+    ):
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
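The reformatted condition only adds PHI3SMALL to the dispatch. If the chain keeps growing, a set membership test is an equivalent, flatter spelling; a hedged alternative, not what the commit does, with stand-ins for the ModelType constants compared in get_model:

# Stand-ins for the ModelType constants used in get_model.
LLAMA, BAICHUAN, PHI3, PHI3SMALL = "llama", "baichuan", "phi3", "phi3small"

model_type = "phi3small"
if model_type in {LLAMA, BAICHUAN, PHI3, PHI3SMALL}:
    print("dispatch to the FlashLlama path")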
@@ -70,6 +70,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=bias,
         )
+    elif config.model_type == "phi3small":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=bias,
+        )

     # otherwise, load the default attention based on the number of heads
     return TensorParallelColumnLinear.load_multi(
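The new branch exists because Phi-3-small checkpoints store the attention projections fused in a single {prefix}.query_key_value tensor, so load_qkv reads one matrix instead of three separate q/k/v weights. A minimal sketch of slicing such a fused matrix, with illustrative GQA shapes rather than TGI code:

import torch

# Illustrative Phi-3-small-like shapes: 32 query heads, 8 KV heads (GQA).
hidden_size, num_heads, num_kv_heads, head_size = 4096, 32, 8, 128
qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_size, hidden_size)

# Slice the fused matrix into query, key and value blocks along dim 0.
q, k, v = torch.split(
    qkv,
    [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
    dim=0,
)
print(q.shape, k.shape, v.shape)  # (4096, 4096) (1024, 4096) (1024, 4096)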
@@ -93,12 +100,20 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        if config.model_type == "phi3small":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_embedding_base,
+                device=weights.device,
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )

         self.softmax_scale = self.head_size**-0.5
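The two branches differ only in which config attribute supplies the rotary base: rope_embedding_base in Phi-3-small configs, rope_theta elsewhere. The duplication could collapse to a single call; a hedged refactor sketch, assuming nothing else diverges between the branches:

# Stand-in config object; the real one is a transformers PretrainedConfig.
class Cfg:
    model_type = "phi3small"
    rope_embedding_base = 10000.0

config = Cfg()
key = "rope_embedding_base" if config.model_type == "phi3small" else "rope_theta"
base = getattr(config, key)
print(base)  # 10000.0 -- passed as base= to PositionRotaryEmbedding.static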
@@ -114,12 +129,21 @@ class FlashLlamaAttention(torch.nn.Module):

         self.query_key_value = load_attention(config, prefix, weights)

-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
+        if config.model_type == "phi3small":
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.dense",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            )

         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
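Here the only difference is the weight name: Phi-3-small checkpoints call the attention output projection "dense", Llama-style ones "o_proj". A small lookup table keeps that naming in one place; an illustrative sketch, not TGI code:

# Output-projection name per model type; anything absent falls back to
# the Llama convention.
O_PROJ_NAME = {"phi3small": "dense"}

def o_proj_prefix(model_type: str, prefix: str) -> str:
    return f"{prefix}.{O_PROJ_NAME.get(model_type, 'o_proj')}"

assert o_proj_prefix("phi3small", "model.layers.0.self_attn") == "model.layers.0.self_attn.dense"
assert o_proj_prefix("llama", "model.layers.0.self_attn") == "model.layers.0.self_attn.o_proj"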
@@ -209,6 +233,13 @@ class LlamaMLP(nn.Module):
                 weights=weights,
                 bias=bias,
             )
+        elif config.model_type == "phi3small":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.up_proj",
+                weights=weights,
+                bias=bias,
+            )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
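load_gate_up covers the case where a checkpoint fuses the MLP gate and up projections into one tensor (here read from {prefix}.up_proj), splitting it on load instead of reading two weights as load_multi does. A minimal sketch of that layout, with illustrative sizes rather than TGI code:

import torch

# A fused tensor stacks [gate; up] along the output dimension.
hidden_size, intermediate_size = 4096, 14336  # illustrative sizes
gate_up = torch.randn(2 * intermediate_size, hidden_size)
gate, up = torch.chunk(gate_up, 2, dim=0)
print(gate.shape, up.shape)  # (14336, 4096) twice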
@@ -259,13 +290,16 @@ class FlashLlamaLayer(nn.Module):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

+        if config.model_type == "phi3small":
+            eps = config.layer_norm_epsilon
+        else:
+            eps = config.rms_norm_eps
+
         self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
         )
         self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
         )

     def forward(
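Phi-3-small configs name their norm epsilon layer_norm_epsilon while Llama-style configs use rms_norm_eps; hoisting eps keeps both layernorm loads on a single code path. A hedged helper that reads whichever key the config defines (not in the repo):

def norm_eps(config, default: float = 1e-5) -> float:
    # Prefer rms_norm_eps (Llama-style), fall back to layer_norm_epsilon.
    for key in ("rms_norm_eps", "layer_norm_epsilon"):
        value = getattr(config, key, None)
        if value is not None:
            return value
    return default

class Cfg:  # stand-in for a phi3small config
    layer_norm_epsilon = 1e-5

print(norm_eps(Cfg()))  # 1e-05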
@@ -327,11 +361,19 @@ class FlashLlamaModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+
+        if config.model_type == "phi3small":
+            self.norm = FastRMSNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            self.norm = FastRMSNorm.load(
+                prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )

         self.gradient_checkpointing = False
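The final-norm branch mirrors the per-layer change: Phi-3-small checkpoints store the last norm under model.final_layernorm with layer_norm_epsilon. Note that the new branch hard-codes that prefix and skips the `if not prefix` guard the Llama branch applies, which matters only when the model is loaded under a nested prefix. A sketch of a prefix helper keeping both behaviors (illustrative, not the repo's code):

def final_norm_prefix(model_type: str, prefix: str = "") -> str:
    name = "model.final_layernorm" if model_type == "phi3small" else "model.norm"
    return name if not prefix else f"{prefix}.{name}"

assert final_norm_prefix("phi3small") == "model.final_layernorm"
assert final_norm_prefix("llama", "transformer") == "transformer.model.norm"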