Consistently take `prefix` in model constructors (#2191)
* Consistently take `prefix` in model constructors
* Release test check fix
* Misc refactor-related fixes

This commit is contained in:
parent 67ef0649cf
commit 05c094fcfa
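In short, every model's `__init__` now receives a `prefix` argument and derives weight names from it instead of hard-coding a root such as `model`, `transformer` or `gpt_neox`. A minimal sketch of the pattern, using hypothetical `FooModel`/`FooForCausalLM` names rather than any class from this diff, with `weights` assumed to expose `get_tensor()` as the loaders below do:

from torch import nn


# Minimal sketch only; FooModel/FooForCausalLM are hypothetical names.
class FooModel(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        # Weight lookups are rooted at the caller-supplied prefix.
        self.embed_weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.embed_tokens.weight")
        )


class FooForCausalLM(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        # Top-level wrappers fall back to their historical root when no prefix
        # is given, so standalone checkpoints keep loading unchanged; a
        # non-empty prefix nests every lookup under it.
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"
        self.model = FooModel(prefix, config, weights)

Callers such as `CausalLM` can then pass an empty prefix (see the `CausalLM` hunk below) and each architecture resolves its own default root.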
@@ -153,7 +153,7 @@ jobs:
     runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     env:
-      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }}
+      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

@@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.opt_modeling import OPTForCau
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
+from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
@@ -522,7 +523,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            batch_class=CausalLMBatchKeysLast,
+            batch_class=BloomCausalLMBatch,
         )
     elif model_type == MPT:
         return CausalLM(

@@ -553,7 +553,8 @@ class CausalLM(Model):
         if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

-        model = model_class(config, weights)
+        prefix = ""
+        model = model_class(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)
         super().__init__(

@@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel):


 class BloomForCausalLM(BloomPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)

@@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):


 class CLIPTextTransformer(nn.Module):
-    def __init__(self, config: CLIPTextConfig):
+    def __init__(self, prefix: str, config: CLIPTextConfig):
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
@@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):

     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

-    def __init__(self, config: CLIPTextConfig):
+    def __init__(self, prefix, config: CLIPTextConfig):
         super().__init__(config)
-        self.text_model = CLIPTextTransformer(config)
+        self.text_model = CLIPTextTransformer(prefix, config)
         # Initialize weights and apply final processing
         self.post_init()

@@ -363,9 +363,9 @@ class CohereMLP(nn.Module):


 class FlashCohereLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = FlashCohereAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):


 class FlashCohereModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 FlashCohereLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
             ]
         )
         self.norm = FastLayerNorm.load_no_bias(
-            prefix="model.norm", weights=weights, eps=config.layer_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
         )

         self.gradient_checkpointing = False
@@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):


 class FlashCohereForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

-        self.model = FlashCohereModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = FlashCohereModel(prefix, config, weights)
         try:
             self.lm_head = SpeculativeHead.load(
                 config,
@@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         except RuntimeError:
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix="model.embed_tokens",
+                prefix=f"{prefix}.embed_tokens",
                 weights=weights,
             )
         self.logit_scale = config.logit_scale

@@ -593,9 +593,9 @@ class DenseMoE(nn.Module):


 class DbrxLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"transformer.blocks.{layer_id}"
+        prefix = f"{prefix}.blocks.{layer_id}"

         self.attn = DbrxNormAttentionNorm(
             prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):


 class DbrxModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
-            prefix="transformer.wte", weights=weights
+            prefix=f"{prefix}.wte", weights=weights
         )

         self.layers = nn.ModuleList(
             [
                 DbrxLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
             ]
         )
         self.norm = FastLayerNorm.load_no_bias(
-            prefix="transformer.norm_f", weights=weights, eps=1e-5
+            prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
         )

         self.head_size = self.layers[0].attn.self_attn.head_size
@@ -702,9 +703,14 @@ class DbrxModel(torch.nn.Module):


 class FlashDbrxForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

-        self.model = DbrxModel(config, weights)
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
+        self.model = DbrxModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,

@@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):

 class Gemma2FastRMSNorm(FastRMSNorm):
     @classmethod
-    def load(cls, prefix, weights, eps=1e-6):
+    def load(cls, prefix: str, weights, eps=1e-6):
         dtype = weights.dtype
         weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
@@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
         return hidden_states.to(self.dtype), residual


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):


 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",
@@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):


 class FlashGemma2Model(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()

         process_group = weights.process_group
@@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):


 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, *, causal: bool = True):
+    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
         super().__init__()

         embed_norm = config.hidden_size**0.5

@@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):

 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
-    def load(cls, prefix, weights, eps=1e-6):
+    def load(cls, prefix: str, weights, eps=1e-6):
         dtype = weights.dtype
         weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
@@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         return hidden_states.to(self.dtype), residual


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):


 class GemmaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):


 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):


 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()

         process_group = weights.process_group
@@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):


 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, *, causal: bool = True):
+    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
         super().__init__()

         embed_norm = config.hidden_size**0.5

@@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):


 class GPT2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         act = config.activation_function
         self.act = (
@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):


 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
@@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):


 class FlashGPT2Model(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         process_group = weights.process_group
@@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):


 class FlashGPT2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(

@@ -54,7 +54,7 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


-def load_attention(config, prefix, weights, layer_id):
+def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
     head_size = config.hidden_size // config.num_attention_heads
@@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):


 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(

@@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):


 class MistralMLP(nn.Module):
-    def __init__(self, prefix, config, weights, layer_id):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@@ -328,7 +328,7 @@ class MistralMLP(nn.Module):


 class MistralLayer(nn.Module):
-    def __init__(self, prefix, config, weights, layer_id):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.self_attn = MistralAttention(
             prefix=f"{prefix}.self_attn",
@@ -392,7 +392,7 @@ class MistralLayer(nn.Module):


 class MistralModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         process_group = weights.process_group
@@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):


 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, name=None):
+    def __init__(self, prefix: str, config, weights, name=None):
         if name is None:
             name = "model"
         super().__init__()

@@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
     return x.view(1) if len(x.size()) == 0 else x


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
     )


-def _load_experts(config, prefix, mat, weights):
+def _load_experts(config, prefix: str, mat, weights):
     if config.quantize is not None:
         raise NotImplementedError("Mixtral does not support weight quantization yet.")

@@ -475,7 +475,7 @@ class DenseMoE(nn.Module):


 class MixtralLayer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"

@@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):


 class MixtralModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
@@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):


 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         self.model = MixtralModel(prefix, config, weights)

@@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):


 class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config

         self.embed_in = TensorParallelEmbedding(
-            prefix="gpt_neox.embed_in", weights=weights
+            prefix=f"{prefix}.embed_in", weights=weights
         )

         self.layers = nn.ModuleList(
@@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             ]
         )
         self.final_layer_norm = FastLayerNorm.load(
-            prefix="gpt_neox.final_layer_norm",
+            prefix=f"{prefix}.final_layer_norm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
@@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):


 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__(config)
-        self.gpt_neox = FlashGPTNeoXModel(config, weights)
+
+        if not prefix:
+            prefix = "gpt_neox"
+        else:
+            prefix = f"{prefix}.gpt_neox"
+
+        self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)

         self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights

@@ -258,9 +258,9 @@ class PhiMLP(nn.Module):


 class FlashPhiLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = FlashPhiAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):


 class FlashPhiModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 FlashPhiLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):


 class FlashPhiForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

-        self.model = FlashPhiModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = FlashPhiModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",

@@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):


 class Qwen2Layer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = Qwen2Attention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):


 class Qwen2Model(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 Qwen2Layer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )

         self.gradient_checkpointing = False
@@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):


 class Qwen2ForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()

-        self.model = Qwen2Model(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = Qwen2Model(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",

@@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
         config,
-        prefix,
+        prefix: str,
         weights,
     ):
         super().__init__()
@@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
         config,
-        prefix,
+        prefix: str,
         weights,
     ):
         super().__init__()
@@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):


 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, prefix: str, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu

@@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
     def __init__(
         self,
         layer_id,
+        prefix: str,
         config,
         weights,
     ):
@@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn

-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"

         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
@@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):


 class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, prefix: str, weights):
         super().__init__()
         self.num_ln = config.num_ln_in_parallel_attn

@@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):


 class FlashRWLargeLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, layer_id, prefix: str, config, weights):
         super().__init__()
-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"

         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)

@@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):


 class FlashRWModel(FlashRWPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config

         self.word_embeddings = TensorParallelEmbedding(
-            prefix="transformer.word_embeddings", weights=weights
+            prefix=f"{prefix}.word_embeddings", weights=weights
         )

         if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, prefix, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
@@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
         else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, prefix, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = self.h[0].self_attention.num_heads_kv

         self.ln_f = FastLayerNorm.load(
-            prefix="transformer.ln_f",
+            prefix=f"{prefix}.ln_f",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
@@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):


 class FlashRWForCausalLM(FlashRWPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)

-        self.transformer = FlashRWModel(config, weights)
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
+        self.transformer = FlashRWModel(prefix, config, weights)

         self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

@@ -346,9 +346,9 @@ class MLP(nn.Module):


 class Block(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"
         self.ln_1 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
         )
@@ -396,18 +396,18 @@ class Block(nn.Module):


 class FlashSantacoderModel(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.config = config

         self.process_group = weights.process_group
         self.wte = TensorParallelEmbedding(
-            prefix="transformer.wte",
+            prefix=f"{prefix}.wte",
             weights=weights,
             reduce=False,
         )
         self.wpe = TensorParallelEmbedding(
-            prefix="transformer.wpe",
+            prefix=f"{prefix}.wpe",
             weights=weights,
             reduce=False,
         )
@@ -415,6 +415,7 @@ class FlashSantacoderModel(nn.Module):
         self.layers = nn.ModuleList(
             [
                 Block(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -466,10 +467,16 @@ class FlashSantacoderModel(nn.Module):
 class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.model = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
-            config, prefix="transformer.wte", weights=weights
+            config, prefix=f"{prefix}.wte", weights=weights
         )

     def forward(

@@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):


 class MPTModel(MPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         # config._validate_config()
         super().__init__(config)
         self.world_size = weights.process_group.size()
@@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
                 f"Requested norm type ({config.norm_type}) is not implemented within this repo."
             )

-        self.wte = TensorParallelEmbedding("transformer.wte", weights)
+        self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)

         if not self.alibi:
-            self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
+            self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
         self.blocks = nn.ModuleList(
             [
-                MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
+                MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
                 for i in range(config.n_layers)
             ]
         )
@@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):


 class MPTForCausalLM(MPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
+
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
-        self.transformer = MPTModel(config, weights)
+        self.transformer = MPTModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
-            config, prefix="transformer.wte", weights=weights
+            config, prefix=f"{prefix}.wte", weights=weights
         )
         self.logit_scale = None
         if config.logit_scale is not None:

@@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):


 class GPTNeoXLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, layer_id, prefix: str, config, weights):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm.load(
-            prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+            prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
         self.post_attention_layernorm = nn.LayerNorm.load(
-            prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+            prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
         self.attention = GPTNeoXAttention(
-            config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+            config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
         )
         self.mlp = GPTNeoXMLP(
-            config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+            config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
         )

     def forward(
@@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):


 class GPTNeoXModel(GPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config

         self.num_attention_heads = config.num_attention_heads

         self.embed_in = TensorParallelEmbedding(
-            prefix="gpt_neox.embed_in", weights=weights
+            prefix=f"{prefix}.embed_in", weights=weights
         )
         self.layers = nn.ModuleList(
             [
-                GPTNeoXLayer(layer_id, config, weights)
+                GPTNeoXLayer(layer_id, prefix, config, weights)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.final_layer_norm = nn.LayerNorm.load(
-            prefix="gpt_neox.final_layer_norm",
+            prefix=f"{prefix}.final_layer_norm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
@@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
-        self.gpt_neox = GPTNeoXModel(config, weights)
+
+        if not prefix:
+            prefix = "gpt_neox"
+        else:
+            prefix = f"{prefix}.gpt_neox"
+
+        self.gpt_neox = GPTNeoXModel(prefix, config, weights)
         self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
     This module learns positional embeddings up to a fixed maximum size.
     """

-    def __init__(self, weights):
+    def __init__(self, prefix: str, weights):
         super().__init__()
         self.offset = 2
         self.weight = nn.Parameter(
-            weights.get_tensor("model.decoder.embed_positions.weight")
+            weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
         )

     def forward(
@@ -311,11 +311,11 @@ class OPTAttention(nn.Module):


 class OPTDecoderLayer(nn.Module):
-    def __init__(self, layer_id: int, config: OPTConfig, weights):
+    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"model.decoder.layers.{layer_id}"
+        prefix = f"{prefix}.decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
@@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):


 class OPTDecoder(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.layerdrop
@@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
         self.vocab_size = config.vocab_size

         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.decoder.embed_tokens", weights=weights
+            prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )
-        self.embed_positions = OPTLearnedPositionalEmbedding(weights)
+        self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)

         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = FastLinear.load(
-                config, prefix="model.decoder.project_out", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_out",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_out = None

         if config.word_embed_proj_dim != config.hidden_size:
             self.project_in = FastLinear.load(
-                config, prefix="model.decoder.project_in", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_in",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_in = None
@@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm.load(
-                prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
+                prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
             )
         else:
             self.final_layer_norm = None

         self.layers = nn.ModuleList(
             [
-                OPTDecoderLayer(layer_id, config, weights)
+                OPTDecoderLayer(layer_id, prefix, config, weights)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):


 class OPTModel(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
-        self.decoder = OPTDecoder(config, weights)
+        self.decoder = OPTDecoder(prefix, config, weights)
         # Initialize weights and apply final processing

     def forward(
@@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):


 class OPTForCausalLM(OPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__(config)

-        self.model = OPTModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = OPTModel(prefix, config, weights)

         self.lm_head = SpeculativeHead.load(
-            config, prefix="model.decoder.embed_tokens", weights=weights
+            config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )

     def forward(

@@ -248,16 +248,16 @@ class PhiBlock(nn.Module):

 # PhiModel implements the embedding layer and the transformer blocks.
 class PhiModel(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.tp_rank = weights.process_group.rank()
         self.tp_world_size = weights.process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="transformer.embd.wte", weights=weights
+            prefix=f"{prefix}.embd.wte", weights=weights
         )
         self.blocks = nn.ModuleList(
             [
-                PhiBlock(f"transformer.h.{layer_id}", config, weights)
+                PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
                 for layer_id in range(config.n_layer)
             ]
         )
@@ -289,9 +289,15 @@ class PhiModel(nn.Module):

 # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
 class PhiForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = PhiModel(config, weights)
+
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
+        self.model = PhiModel(prefix, config, weights)
         self.lm_head = PhiCausalLMHead(config, weights)

     def forward(

@@ -878,10 +878,6 @@ class FlashCausalLM(Model):
         )
         config.quantize = quantize
         config.speculator = speculator
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None

         torch.distributed.barrier(group=self.process_group)

@@ -900,13 +896,22 @@ class FlashCausalLM(Model):
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config
+
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
         if num_kv_heads is None:
-            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            # Order is important here.
+            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
+                num_kv_heads = getattr(config, attr, None)
+                if num_kv_heads is not None:
+                    break
         if num_kv_heads is None:
-            # Final overide for GPT2
-            num_kv_heads = config.n_head
+            raise ValueError("Cannot get the number of key/value heads")
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads

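For reference, the key/value-head fallback introduced in the last hunk can be exercised on its own; a small sketch with a hypothetical config stand-in:

# Standalone sketch of the fallback order above; Cfg is a hypothetical stand-in
# for a transformers config object that only exposes n_head (GPT-2 style).
class Cfg:
    n_head = 12


def resolve_num_kv_heads(config):
    # Order is important: explicit KV heads first, then attention heads, then n_head.
    for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
        num_kv_heads = getattr(config, attr, None)
        if num_kv_heads is not None:
            return num_kv_heads
    raise ValueError("Cannot get the number of key/value heads")


print(resolve_num_kv_heads(Cfg()))  # -> 12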