This commit is contained in:
AUTOMATIC 2023-05-11 07:45:05 +03:00
parent c9e5b92106
commit e334758ec2
1 changed files with 4 additions and 12 deletions

View File

@ -201,23 +201,15 @@ def efficient_dot_product_attention(
key=key,
value=value,
)
# slices of res tensor are mutable, modifications made
# to the slices will affect the original tensor.
# if output of compute_query_chunk_attn function has same number of
# dimensions as input query tensor, we initialize tensor like this:
num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
query_shape = get_query_chunk(0).shape
res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
res_dtype = get_query_chunk(0).dtype
res = torch.zeros(res_shape, dtype=res_dtype)
for i in range(num_query_chunks):
res = torch.zeros_like(query)
for i in range(math.ceil(q_tokens / query_chunk_size)):
attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
)
res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
return res