Hotfix: fix MPT after recent refactor (#2257)
parent 18db78f295
commit 3b41e93a09
@@ -492,7 +492,7 @@ class CausalLMBatch(Batch):
 
 
 @dataclass
-class CausalLMBatchKeysLast(Batch):
+class CausalLMBatchKeysLast(CausalLMBatch):
     keys_head_dim_last: bool = False
 
 
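The one-line fix above is the crux of the hotfix: CausalLMBatchKeysLast exists only to flip the keys_head_dim_last flag, so it must subclass CausalLMBatch to inherit the concrete batch construction and bookkeeping; subclassing the abstract Batch left those methods undefined. A minimal sketch of the pattern, with placeholder bodies that are not TGI's real implementations:

from dataclasses import dataclass


@dataclass
class Batch:
    """Abstract base: declares the batch interface only."""


@dataclass
class CausalLMBatch(Batch):
    keys_head_dim_last: bool = True

    @classmethod
    def from_pb(cls, pb) -> "CausalLMBatch":
        # Placeholder for the real construction logic.
        return cls()


@dataclass
class CausalLMBatchKeysLast(CausalLMBatch):
    # Subclassing CausalLMBatch (not Batch) inherits from_pb and the
    # rest of the batch machinery; only the key-layout flag changes.
    keys_head_dim_last: bool = False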
@@ -544,7 +544,12 @@ class CausalLM(Model):
         config.quantize = quantize
         config.speculator = speculator
         if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = config.pad_token_id
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
 
         torch.distributed.barrier(group=self.process_group)
         weights_loader = get_loader(
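The replaced line assigned config.pad_token_id unconditionally, which leaves pad_token_id as None whenever the config does not define one; the new block falls back through the config's EOS id and then the tokenizer's. The same precedence as a standalone sketch, with a hypothetical helper name:

def resolve_pad_token_id(tokenizer, config):
    """Hypothetical helper mirroring the fallback chain above."""
    if tokenizer.pad_token_id is not None:
        return  # already set; nothing to do
    if config.pad_token_id is not None:
        tokenizer.pad_token_id = config.pad_token_id
    elif config.eos_token_id is not None:
        tokenizer.pad_token_id = config.eos_token_id
    elif tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # If all three sources are None, pad_token_id stays unset.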
@@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
         weights,
     ):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
 
         if self.n_heads % weights.process_group.size() != 0:
             raise ValueError(
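This hunk and every remaining one make the same mechanical substitution: after the refactor, config.attn_config is an object with attributes rather than a plain dict, so each attn_config["key"] lookup becomes attn_config.key. One way such a config object can be built from the original dict, sketched with an illustrative dataclass (field names mirror the keys read in these hunks; defaults are placeholders, not MPT's canonical values):

from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnConfig:
    attn_type: str = "multihead_attention"
    attn_impl: str = "torch"
    clip_qkv: Optional[float] = None
    qk_ln: bool = False
    softmax_scale: Optional[float] = None
    attn_pdrop: float = 0.0
    prefix_lm: bool = False
    attn_uses_sequence_id: bool = False
    alibi: bool = False
    alibi_bias_max: int = 8


raw = {"attn_impl": "torch", "qk_ln": True}  # e.g. parsed from config.json
attn_config = AttnConfig(**raw)
assert attn_config.qk_ln  # attribute access, as in the diff's + lines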
@@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):
 
     def __init__(self, config, prefix, weights):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
         self.Wqkv = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        if config.attn_config["attn_type"] != "multihead_attention":
+        if config.attn_config.attn_type != "multihead_attention":
             raise NotImplementedError(
-                f"""Not implemented attn {config.attn_config["attn_type"]}"""
+                f"""Not implemented attn {config.attn_config.attn_type}"""
             )
         resid_pdrop = config.resid_pdrop
         if config.no_bias:
@@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
         self.world_size = weights.process_group.size()
         self.rank = weights.process_group.rank()
         self.n_heads = config.n_heads
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.prefix_lm = config.attn_config["prefix_lm"]
-        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-        self.alibi = config.attn_config["alibi"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.attn_impl = config.attn_config.attn_impl
+        self.prefix_lm = config.attn_config.prefix_lm
+        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
+        self.alibi = config.attn_config.alibi
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if config.init_device == "mixed":
             if dist.get_local_rank() == 0:
                 config.init_device = "cpu"