Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)

This commit is contained in:
Mario928 2023-10-19 15:24:26 +05:30 committed by GitHub
parent 7402a355dc
commit 9179605e1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(
query = query.reshape(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,