feat: adjust to load weights

drbh 2024-06-05 11:48:21 +00:00
parent 8aece3bd68
commit cf8fdef9d3
2 changed files with 74 additions and 22 deletions

View File

@@ -136,6 +136,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    PHI3SMALL = {
+        "type": "phi3small",
+        "name": "Phi 3 Small",
+        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",
@@ -579,7 +584,12 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == PHI3SMALL
+    ):
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
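Note on the dispatch above: get_model keys off the model_type string read from the checkpoint's config.json, so a Phi-3-small repository now falls through to the FlashLlama path. A minimal sketch of inspecting that field (assumes the transformers library and Hub access; not part of this diff):

from transformers import AutoConfig

# Phi-3-small ships custom modeling code, hence trust_remote_code=True.
config = AutoConfig.from_pretrained(
    "microsoft/Phi-3-small-8k-instruct", trust_remote_code=True
)
print(config.model_type)  # expected "phi3small", matching the new PHI3SMALL entry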

View File

@@ -70,6 +70,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=bias,
         )
+    elif config.model_type == "phi3small":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=bias,
+        )
     # otherwise, load the default attention based on the number of heads
     return TensorParallelColumnLinear.load_multi(
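The new branch is needed because Phi-3-small checkpoints store the attention input projection as one fused query_key_value tensor rather than separate q_proj/k_proj/v_proj weights; load_qkv presumably takes care of splitting and tensor-parallel sharding. A rough illustration of such a fused layout, with hypothetical sizes and assuming a simple [q; k; v] stacking (the real checkpoint may order or interleave heads differently):

import torch

# Hypothetical sizes, for illustration only.
hidden_size, num_heads, num_kv_heads, head_size = 4096, 32, 8, 128

# Fused weight as stored under "...query_key_value.weight".
qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_size, hidden_size)

# Split back into per-projection blocks.
q, k, v = torch.split(
    qkv,
    [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
    dim=0,
)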
@@ -93,12 +100,20 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        if config.model_type == "phi3small":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_embedding_base,
+                device=weights.device,
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
         self.softmax_scale = self.head_size**-0.5
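Phi-3-small exposes the rotary base under rope_embedding_base rather than rope_theta, hence the branch above. A more compact alternative (a sketch, not what this commit does) would be a getattr fallback:

def rope_base(config):
    # Prefer the phi3small-style field name, fall back to the Llama-style one.
    return getattr(config, "rope_embedding_base", None) or config.rope_theta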
@@ -114,12 +129,21 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
+        if config.model_type == "phi3small":
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.dense",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            )
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -209,6 +233,13 @@ class LlamaMLP(nn.Module):
                 weights=weights,
                 bias=bias,
             )
+        elif config.model_type == "phi3small":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.up_proj",
+                weights=weights,
+                bias=bias,
+            )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
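Here load_gate_up is pointed at {prefix}.up_proj, which suggests Phi-3-small stores the MLP gate and up projections fused in a single up_proj tensor (Llama keeps gate_proj and up_proj separate). A toy illustration of splitting such a fused weight, with hypothetical sizes and assuming gate-then-up stacking:

import torch

# Hypothetical sizes, for illustration only.
hidden_size, intermediate_size = 4096, 14336

# Fused [gate | up] weight as assumed under "...up_proj.weight".
gate_up = torch.randn(2 * intermediate_size, hidden_size)
gate, up = gate_up.chunk(2, dim=0)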
@@ -259,13 +290,16 @@ class FlashLlamaLayer(nn.Module):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        if config.model_type == "phi3small":
+            eps = config.layer_norm_epsilon
+        else:
+            eps = config.rms_norm_eps
         self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
         )
         self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
         )

     def forward(
@@ -327,11 +361,19 @@ class FlashLlamaModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+        if config.model_type == "phi3small":
+            self.norm = FastRMSNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            self.norm = FastRMSNorm.load(
+                prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )

         self.gradient_checkpointing = False
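Taken together, the phi3small branches in this file amount to a small renaming layer on top of the existing Llama code path; summarized as data (compiled from the hunks above, not code that exists in the repository):

# Llama-style name -> Phi-3-small name, per the phi3small branches above.
PHI3SMALL_OVERRIDES = {
    "q_proj/k_proj/v_proj": "query_key_value",  # fused attention input projection
    "o_proj": "dense",                          # attention output projection
    "gate_proj/up_proj": "up_proj",             # fused MLP input projection
    "model.norm": "model.final_layernorm",      # final layer norm
    "rope_theta": "rope_embedding_base",        # rotary base config field
    "rms_norm_eps": "layer_norm_epsilon",       # layer norm epsilon config field
}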