Consistently take `prefix` in model constructors (#2191)

* Consistently take `prefix` in model constructors

* Release test check fix

* Misc refactor-related fixes
Daniël de Kok 2024-07-05 16:07:48 +02:00 committed by GitHub
parent 67ef0649cf
commit 05c094fcfa
23 changed files with 210 additions and 131 deletions
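
The pattern the diffs below apply is the same in every `*ForCausalLM` constructor: take a `prefix` argument, default it to the architecture's usual top-level weight namespace when it is empty, and thread it down to the inner model so every sub-module builds its weight names under that prefix. A minimal sketch of the convention, using hypothetical `ExampleModel`/`ExampleForCausalLM` names (not taken from any single file in this commit):

import torch


class ExampleModel(torch.nn.Module):
    """Hypothetical inner model: builds all of its weight names under `prefix`."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        # e.g. embeddings would be loaded from f"{prefix}.embed_tokens",
        # the final norm from f"{prefix}.norm", and so on.
        self.embed_tokens_name = f"{prefix}.embed_tokens"


class ExampleForCausalLM(torch.nn.Module):
    """Hypothetical wrapper showing the prefix handling added in this commit."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        # Empty prefix: weights sit under the architecture's usual top-level
        # namespace ("model" here; "transformer" or "gpt_neox" in other files).
        # Non-empty prefix (e.g. when the language model is nested inside a
        # larger model): nest the usual namespace under it.
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"
        self.model = ExampleModel(prefix, config, weights)


# Standalone model:  ExampleForCausalLM("", config, weights)
#   -> weights read from "model.embed_tokens"
# Nested model:      ExampleForCausalLM("language_model", config, weights)
#   -> weights read from "language_model.model.embed_tokens"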


@ -153,7 +153,7 @@ jobs:
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }}
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4


@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@ -522,7 +523,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
batch_class=BloomCausalLMBatch,
)
elif model_type == MPT:
return CausalLM(


@ -553,7 +553,8 @@ class CausalLM(Model):
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(


@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)


@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix: str, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
self.text_model = CLIPTextTransformer(prefix, config)
# Initialize weights and apply final processing
self.post_init()


@ -363,9 +363,9 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):
class FlashCohereModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashCohereLayer(
prefix,
layer_id,
config,
weights,
@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
)
self.gradient_checkpointing = False
@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):
class FlashCohereForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = FlashCohereModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashCohereModel(prefix, config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
prefix=f"{prefix}.embed_tokens",
weights=weights,
)
self.logit_scale = config.logit_scale


@ -593,9 +593,9 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.blocks.{layer_id}"
prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):
class DbrxModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.wte", weights=weights
prefix=f"{prefix}.wte", weights=weights
)
self.layers = nn.ModuleList(
[
DbrxLayer(
prefix,
layer_id,
config,
weights,
@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=1e-5
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
)
self.head_size = self.layers[0].attn.self_attn.head_size
@ -702,9 +703,14 @@ class DbrxModel(torch.nn.Module):
class FlashDbrxForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = DbrxModel(config, weights)
self.lm_head = SpeculativeHead.load(
config,


@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):
class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, *, causal: bool = True):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5


@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):
class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, *, causal: bool = True):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5


@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.activation_function
self.act = (
@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(


@ -54,7 +54,7 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights, layer_id):
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(


@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__()


@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
)
def _load_experts(config, prefix, mat, weights):
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
@ -475,7 +475,7 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = MixtralModel(prefix, config, weights)


@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
]
)
self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights


@ -258,9 +258,9 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
prefix,
layer_id,
config,
weights,
@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashPhiModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",


@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
class Qwen2Model(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Qwen2Layer(
prefix,
layer_id,
config,
weights,
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
]
)
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
class Qwen2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen2Model(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Qwen2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",


@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.act = torch.nn.functional.gelu
@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
def __init__(
self,
layer_id,
prefix: str,
config,
weights,
):
@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights
prefix=f"{prefix}.word_embeddings", weights=weights
)
if config.new_decoder_architecture:
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, config, weights)
FlashRWLargeLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.h = nn.ModuleList(
[
FlashRWLayer(layer_id, config, weights)
FlashRWLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = self.h[0].self_attention.num_heads_kv
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
prefix=f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = FlashRWModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.transformer = FlashRWModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)


@ -346,9 +346,9 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
@ -396,18 +396,18 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.config = config
self.process_group = weights.process_group
self.wte = TensorParallelEmbedding(
prefix="transformer.wte",
prefix=f"{prefix}.wte",
weights=weights,
reduce=False,
)
self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe",
prefix=f"{prefix}.wpe",
weights=weights,
reduce=False,
)
@ -415,6 +415,7 @@ class FlashSantacoderModel(nn.Module):
self.layers = nn.ModuleList(
[
Block(
prefix,
layer_id,
config,
weights,
@ -466,10 +467,16 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
config.transpose = config.architectures[0].startswith("GPT2")
self.model = FlashSantacoderModel(config, weights)
self.model = FlashSantacoderModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
def forward(


@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
class MPTModel(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
# config._validate_config()
super().__init__(config)
self.world_size = weights.process_group.size()
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
f"Requested norm type ({config.norm_type}) is not implemented within this repo."
)
self.wte = TensorParallelEmbedding("transformer.wte", weights)
self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
if not self.alibi:
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
for i in range(config.n_layers)
]
)
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights)
self.transformer = MPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
self.logit_scale = None
if config.logit_scale is not None:


@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
)
def forward(
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
)


@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
def __init__(self, prefix: str, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
)
def forward(
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
prefix = f"{prefix}.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_out",
weights=weights,
bias=False,
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_in",
weights=weights,
bias=False,
)
else:
self.project_in = None
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing
def forward(
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
def forward(


@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
prefix=f"{prefix}.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(


@ -878,10 +878,6 @@ class FlashCausalLM(Model):
)
config.quantize = quantize
config.speculator = speculator
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
torch.distributed.barrier(group=self.process_group)
@ -900,13 +896,22 @@ class FlashCausalLM(Model):
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
if num_kv_heads is None:
# Final overide for GPT2
num_kv_heads = config.n_head
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = num_kv_heads // self.process_group.size()
self.head_size = config.hidden_size // config.num_attention_heads
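
For reference, the head-count resolution added in the hunk above reduces to: try a list of config attribute names in order, keep the first one that is present, and raise if none match. A small self-contained sketch of that logic (attribute order as shown in the hunk; `SimpleNamespace` stands in for a real config object):

from types import SimpleNamespace


def resolve_num_kv_heads(config, num_kv_heads=None):
    # Try the candidate attributes in priority order; the first one present wins.
    if num_kv_heads is None:
        for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
            num_kv_heads = getattr(config, attr, None)
            if num_kv_heads is not None:
                break
    if num_kv_heads is None:
        raise ValueError("Cannot get the number of key/value heads")
    return num_kv_heads


# A GPT-2 style config exposes only `n_head`:
print(resolve_num_kv_heads(SimpleNamespace(n_head=12)))  # 12
# A GQA config exposes `num_key_value_heads`, which takes precedence:
print(resolve_num_kv_heads(SimpleNamespace(num_key_value_heads=8, num_attention_heads=32)))  # 8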