wip: debug gemma and flash

This commit is contained in:
drbh 2024-08-09 23:08:54 +00:00
parent 8dcc7d3f6b
commit 7bc16deb48
2 changed files with 74 additions and 25 deletions

View File

@@ -293,6 +293,7 @@ else:
max_s, max_s,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=None,
softcap=None, softcap=None,
): ):
if window_size_left != -1: if window_size_left != -1:
@@ -328,12 +329,33 @@ else:
.reshape(original_shape[0], -1, original_shape[2]) .reshape(original_shape[0], -1, original_shape[2])
) )
original_shape = q.shape
original_type = q.dtype
# ensure type is a float16 and that the last dimension is 128 or less
# as required by the flash attention kernel [flash-attention/csrc/flash_attn/fmha_api.cpp:246]
chunk_size = 128
if q.dtype == torch.bfloat16:
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
# calculate the number of full chunks and the size of the last chunk
num_full_chunks = original_shape[-1] // chunk_size
last_chunk_size = original_shape[-1] % chunk_size
# preallocate the output tensor
out = torch.empty_like(q) out = torch.empty_like(q)
# process full chunks
for i in range(num_full_chunks):
start = i * chunk_size
end = start + chunk_size
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
q, q[..., start:end],
k, k[..., start:end],
v, v[..., start:end],
out, out[..., start:end],
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
max_s, max_s,
@@ -346,4 +368,30 @@ else:
0, 0,
None, None,
) )
# process the last chunk if it exists
if last_chunk_size > 0:
start = num_full_chunks * chunk_size
flash_attn_cuda.fwd(
q[..., start:],
k[..., start:],
v[..., start:],
out[..., start:],
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
# make sure it's the original type
if out.dtype != original_type:
out = out.to(original_type)
return out return out

View File

@@ -500,11 +500,12 @@ def get_model(
if max_input_tokens is not None and max_input_tokens <= sliding_window: if max_input_tokens is not None and max_input_tokens <= sliding_window:
sliding_window = -1 sliding_window = -1
if ( should_use_sliding_window = (
(sliding_window is not None and sliding_window != -1) sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
and not SUPPORTS_WINDOWING )
and max_input_tokens > sliding_window
): if should_use_sliding_window:
if max_input_tokens is not None and max_input_tokens > sliding_window:
raise ValueError( raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
) )