wip: debug gemma and flash

This commit is contained in:
drbh 2024-08-09 23:08:54 +00:00
parent 8dcc7d3f6b
commit 7bc16deb48
2 changed files with 74 additions and 25 deletions

View File

@@ -293,6 +293,7 @@ else:
max_s,
softmax_scale,
window_size_left=-1,
causal=None,
softcap=None,
):
if window_size_left != -1:
@@ -328,22 +329,69 @@ else:
.reshape(original_shape[0], -1, original_shape[2])
)
original_shape = q.shape
original_type = q.dtype
# ensure type is a float16 and that the last dimension is 128 or less
# as required by the flash attention kernel [flash-attention/csrc/flash_attn/fmha_api.cpp:246]
chunk_size = 128
if q.dtype == torch.bfloat16:
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
# calculate the number of full chunks and the size of the last chunk
num_full_chunks = original_shape[-1] // chunk_size
last_chunk_size = original_shape[-1] % chunk_size
# preallocate the output tensor
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
# process full chunks
for i in range(num_full_chunks):
start = i * chunk_size
end = start + chunk_size
flash_attn_cuda.fwd(
q[..., start:end],
k[..., start:end],
v[..., start:end],
out[..., start:end],
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
# process the last chunk if it exists
if last_chunk_size > 0:
start = num_full_chunks * chunk_size
flash_attn_cuda.fwd(
q[..., start:],
k[..., start:],
v[..., start:],
out[..., start:],
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
# make sure it's the original type
if out.dtype != original_type:
out = out.to(original_type)
return out

View File

@@ -500,14 +500,15 @@ def get_model(
if max_input_tokens is not None and max_input_tokens <= sliding_window:
sliding_window = -1
if (
(sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING
and max_input_tokens > sliding_window
):
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
should_use_sliding_window = (
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
)
if should_use_sliding_window:
if max_input_tokens is not None and max_input_tokens > sliding_window:
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION: