Hotfix: fix MPT after recent refactor (#2257)

parent 18db78f295
commit 3b41e93a09
@@ -492,7 +492,7 @@ class CausalLMBatch(Batch):
 
 
 @dataclass
-class CausalLMBatchKeysLast(Batch):
+class CausalLMBatchKeysLast(CausalLMBatch):
     keys_head_dim_last: bool = False
 
 
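Note on the hunk above: the only change is the base class. Deriving from CausalLMBatch instead of the bare Batch interface lets CausalLMBatchKeysLast keep the concrete batch behaviour and override only the keys_head_dim_last flag. A minimal sketch of the pattern (the batch_id field is illustrative, not the real layout):

from dataclasses import dataclass

@dataclass
class CausalLMBatch:
    # stand-in for the real batch fields and methods
    batch_id: int = 0
    keys_head_dim_last: bool = True

@dataclass
class CausalLMBatchKeysLast(CausalLMBatch):
    # only the flag changes; everything else is inherited
    keys_head_dim_last: bool = False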
@@ -544,7 +544,12 @@ class CausalLM(Model):
         config.quantize = quantize
         config.speculator = speculator
         if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = config.pad_token_id
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
 
         torch.distributed.barrier(group=self.process_group)
         weights_loader = get_loader(
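The new branch above replaces the unconditional assignment, which could leave tokenizer.pad_token_id set to None when the model config defines no pad token. A standalone sketch of the same fallback order, assuming plain objects with pad_token_id / eos_token_id attributes (not the actual TGI types):

def resolve_pad_token_id(tokenizer, config):
    # Fallback order from the hunk: config pad token, then config EOS, then tokenizer EOS.
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        return config.eos_token_id
    return tokenizer.eos_token_id  # may still be None if nothing is defined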
@@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
         weights,
     ):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
 
         if self.n_heads % weights.process_group.size() != 0:
             raise ValueError(
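This hunk and the three that follow make the same mechanical change: subscript access into config.attn_config becomes attribute access, which suggests the recent refactor turned attn_config from a plain dict into an attribute-style config object. A minimal sketch of that assumption, using SimpleNamespace as a stand-in:

from types import SimpleNamespace

# Hypothetical stand-in for the refactored nested config.
attn_config = SimpleNamespace(attn_impl="torch", clip_qkv=None, qk_ln=False)

print(attn_config.attn_impl)   # new access style used in the diff -> "torch"
# attn_config["attn_impl"]     # old style now fails: 'SimpleNamespace' object is not subscriptable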
@@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):
 
     def __init__(self, config, prefix, weights):
        super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
         self.Wqkv = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        if config.attn_config["attn_type"] != "multihead_attention":
+        if config.attn_config.attn_type != "multihead_attention":
             raise NotImplementedError(
-                f"""Not implemented attn {config.attn_config["attn_type"]}"""
+                f"""Not implemented attn {config.attn_config.attn_type}"""
             )
         resid_pdrop = config.resid_pdrop
         if config.no_bias:
@@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
         self.world_size = weights.process_group.size()
         self.rank = weights.process_group.rank()
         self.n_heads = config.n_heads
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.prefix_lm = config.attn_config["prefix_lm"]
-        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-        self.alibi = config.attn_config["alibi"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.attn_impl = config.attn_config.attn_impl
+        self.prefix_lm = config.attn_config.prefix_lm
+        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
+        self.alibi = config.attn_config.alibi
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if config.init_device == "mixed":
             if dist.get_local_rank() == 0:
                 config.init_device = "cpu"