Hotfix: fix MPT after recent refactor (#2257)

Daniël de Kok 2024-07-19 14:42:35 +02:00 committed by GitHub
parent 18db78f295
commit 3b41e93a09
2 changed files with 26 additions and 21 deletions

@@ -492,7 +492,7 @@ class CausalLMBatch(Batch):
 @dataclass
-class CausalLMBatchKeysLast(Batch):
+class CausalLMBatchKeysLast(CausalLMBatch):
     keys_head_dim_last: bool = False
@@ -544,7 +544,12 @@ class CausalLM(Model):
         config.quantize = quantize
         config.speculator = speculator
         if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
         torch.distributed.barrier(group=self.process_group)
         weights_loader = get_loader(
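For reference, the branch added above resolves a missing pad token in a fixed priority order: the model config's pad_token_id first, then the config's eos_token_id, then the tokenizer's own eos_token_id. A minimal, self-contained sketch of that fallback order; the SimpleNamespace stand-ins are illustrative only, the real objects are a transformers config and tokenizer:

from types import SimpleNamespace

# Illustrative stand-ins for the real config and tokenizer objects.
config = SimpleNamespace(pad_token_id=None, eos_token_id=0)
tokenizer = SimpleNamespace(pad_token_id=None, eos_token_id=2)

# Same fallback order as the lines added in CausalLM.__init__ above.
if tokenizer.pad_token_id is None:
    if config.pad_token_id is not None:
        tokenizer.pad_token_id = config.pad_token_id
    elif config.eos_token_id is not None:
        tokenizer.pad_token_id = config.eos_token_id
    elif tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

print(tokenizer.pad_token_id)  # 0: config.eos_token_id takes priority over tokenizer.eos_token_id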

@@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
         weights,
     ):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         if self.n_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
         self.Wqkv = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        if config.attn_config["attn_type"] != "multihead_attention":
+        if config.attn_config.attn_type != "multihead_attention":
             raise NotImplementedError(
-                f"""Not implemented attn {config.attn_config["attn_type"]}"""
+                f"""Not implemented attn {config.attn_config.attn_type}"""
             )
         resid_pdrop = config.resid_pdrop
         if config.no_bias:
@@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
         self.world_size = weights.process_group.size()
         self.rank = weights.process_group.rank()
         self.n_heads = config.n_heads
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.prefix_lm = config.attn_config["prefix_lm"]
-        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-        self.alibi = config.attn_config["alibi"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.attn_impl = config.attn_config.attn_impl
+        self.prefix_lm = config.attn_config.prefix_lm
+        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
+        self.alibi = config.attn_config.alibi
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if config.init_device == "mixed":
             if dist.get_local_rank() == 0:
                 config.init_device = "cpu"
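All of the MPT modeling changes above follow one pattern: attention settings are read from config.attn_config with attribute access instead of dict subscription, presumably because the recent refactor turned attn_config from a plain dict into a config object. A minimal sketch of why the old subscription now fails; AttnConfig here is a hypothetical stand-in, not the class used in the repository:

# Before the refactor: attn_config behaved like a plain dict.
attn_config = {"attn_impl": "torch", "clip_qkv": None}
print(attn_config["attn_impl"])  # "torch"

# After the refactor: the same settings live as attributes on an object,
# so subscription raises TypeError and attribute access is required.
class AttnConfig:  # hypothetical stand-in for the refactored config type
    def __init__(self, attn_impl="torch", clip_qkv=None):
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv

attn_config = AttnConfig()
print(attn_config.attn_impl)   # "torch"
# attn_config["attn_impl"]     # TypeError: 'AttnConfig' object is not subscriptable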