feat: adjust to load weights

drbh 2024-06-05 11:48:21 +00:00
parent 8aece3bd68
commit cf8fdef9d3
2 changed files with 74 additions and 22 deletions

View File

@@ -136,6 +136,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    PHI3SMALL = {
+        "type": "phi3small",
+        "name": "Phi 3 Small",
+        "url": "https://huggingface.co/microsoft/Phi-3-small-8k-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",
@@ -579,7 +584,12 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == PHI3SMALL
+    ):
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
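
Note: `get_model` dispatches on the `model_type` string from the checkpoint's config.json, so a `phi3small` config now takes the same FlashLlama path as Llama, Baichuan and Phi 3 whenever flash attention is available. A minimal sketch of that routing, with hypothetical names (`LLAMA_LIKE`, `pick_backend`) used only for illustration, not the actual TGI code:

    # Sketch only: stand-in for the real get_model dispatch above.
    LLAMA_LIKE = {"llama", "baichuan", "phi3", "phi3small"}

    def pick_backend(model_type: str, flash_attention: bool) -> str:
        # phi3small checkpoints reuse the FlashLlama implementation when
        # flash attention kernels are available.
        if model_type in LLAMA_LIKE and flash_attention:
            return "FlashLlama"
        return "CausalLM"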

View File

@@ -70,6 +70,13 @@ def load_attention(config, prefix, weights):
             weights=weights,
             bias=bias,
         )
+    elif config.model_type == "phi3small":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=bias,
+        )
 
     # otherwise, load the default attention based on the number of heads
     return TensorParallelColumnLinear.load_multi(
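
Note: Phi 3 Small checkpoints store the attention weights as a single fused `query_key_value` projection instead of separate `q_proj`/`k_proj`/`v_proj` tensors, which is why `load_attention` routes to `TensorParallelColumnLinear.load_qkv`. A rough sketch of the underlying idea, assuming a plain [q; k; v] block layout along the output dimension (the real loader also shards the result across tensor-parallel ranks, and the actual checkpoint layout may differ):

    import torch

    def split_fused_qkv(weight: torch.Tensor, num_heads: int, num_kv_heads: int, head_size: int):
        # weight: [(num_heads + 2 * num_kv_heads) * head_size, hidden_size]
        q_size = num_heads * head_size
        kv_size = num_kv_heads * head_size
        q, k, v = torch.split(weight, [q_size, kv_size, kv_size], dim=0)
        return q, k, v
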
@@ -93,12 +100,20 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        if config.model_type == "phi3small":
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_embedding_base,
+                device=weights.device,
+            )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                config=config,
+                dim=self.head_size,
+                base=config.rope_theta,
+                device=weights.device,
+            )
 
         self.softmax_scale = self.head_size**-0.5
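
Note: the only difference between the two branches above is the name of the RoPE base in the config: Phi 3 Small exposes `rope_embedding_base` where Llama-style configs expose `rope_theta`. A more compact alternative (a sketch, not what this commit does) would resolve the attribute once:

    def resolve_rope_base(config) -> float:
        # phi3small-style configs name the RoPE base differently from llama-style ones.
        if hasattr(config, "rope_embedding_base"):
            return float(config.rope_embedding_base)
        return float(config.rope_theta)
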
@@ -114,12 +129,21 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
 
-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
+        if config.model_type == "phi3small":
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.dense",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.o_proj = TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
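
Note: for the output projection only the parameter name inside each layer changes: Phi 3 Small calls it `dense` where Llama calls it `o_proj` (and, further down, the fused `up_proj` replaces `gate_up_proj` and `final_layernorm` replaces `model.norm`). A hypothetical name-mapping helper, sketched here only to summarise the renames this diff handles branch by branch:

    # Hypothetical helper, not part of the commit: llama-style name -> phi3small name.
    PHI3SMALL_NAMES = {
        "o_proj": "dense",
        "gate_up_proj": "up_proj",
        "model.norm": "model.final_layernorm",
    }

    def resolve_name(config, llama_name: str) -> str:
        if getattr(config, "model_type", None) == "phi3small":
            return PHI3SMALL_NAMES.get(llama_name, llama_name)
        return llama_name
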
@@ -209,6 +233,13 @@ class LlamaMLP(nn.Module):
                 weights=weights,
                 bias=bias,
             )
+        elif config.model_type == "phi3small":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.up_proj",
+                weights=weights,
+                bias=bias,
+            )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
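
Note: Phi 3 Small also fuses the MLP's gate and up projections into a single `up_proj` tensor, hence `TensorParallelColumnLinear.load_gate_up` instead of loading `gate_proj` and `up_proj` separately. The splitting idea, sketched under the assumption that the fused weight simply stacks [gate; up] along the output dimension (the real loader additionally handles tensor-parallel sharding):

    import torch

    def split_gate_up(fused_weight: torch.Tensor):
        # fused_weight: [2 * intermediate_size, hidden_size]
        gate, up = torch.chunk(fused_weight, 2, dim=0)
        return gate, up
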
@@ -259,13 +290,16 @@ class FlashLlamaLayer(nn.Module):
         )
 
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
+        if config.model_type == "phi3small":
+            eps = config.layer_norm_epsilon
+        else:
+            eps = config.rms_norm_eps
         self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
         )
         self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
         )
 
     def forward(
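
Note: the layer norms themselves are unchanged; only the epsilon comes from a differently named config field (`layer_norm_epsilon` for Phi 3 Small, `rms_norm_eps` otherwise), which the added `eps` local captures once per layer. An equivalent standalone sketch:

    def resolve_norm_eps(config) -> float:
        if getattr(config, "model_type", None) == "phi3small":
            return config.layer_norm_epsilon
        return config.rms_norm_eps
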
@@ -327,11 +361,19 @@ class FlashLlamaModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+
+        if config.model_type == "phi3small":
+            self.norm = FastRMSNorm.load(
+                prefix="model.final_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+        else:
+            self.norm = FastRMSNorm.load(
+                prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
 
         self.gradient_checkpointing = False
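
Note: a quick way to confirm which of these branches a given checkpoint will take is to inspect its config directly. A small sketch using `transformers.AutoConfig`; `trust_remote_code=True` is an assumption that may be required for this repository, and the printed attribute names are the ones this diff reads for `phi3small`:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "microsoft/Phi-3-small-8k-instruct", trust_remote_code=True
    )
    print(config.model_type)            # expected: "phi3small"
    print(config.rope_embedding_base)   # RoPE base used by the phi3small branch
    print(config.layer_norm_epsilon)    # epsilon used for the layer norms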