wip: debug gemma and flash
This commit is contained in:
parent
8dcc7d3f6b
commit
7bc16deb48
|
@ -293,6 +293,7 @@ else:
|
|||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=None,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
|
@ -328,12 +329,33 @@ else:
|
|||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
original_shape = q.shape
|
||||
original_type = q.dtype
|
||||
|
||||
# ensure type is a float16 and that the last dimension is 128 or less
|
||||
# as required by the flash attention kernel [flash-attention/csrc/flash_attn/fmha_api.cpp:246]
|
||||
chunk_size = 128
|
||||
if q.dtype == torch.bfloat16:
|
||||
q = q.to(torch.float16)
|
||||
k = k.to(torch.float16)
|
||||
v = v.to(torch.float16)
|
||||
|
||||
# calculate the number of full chunks and the size of the last chunk
|
||||
num_full_chunks = original_shape[-1] // chunk_size
|
||||
last_chunk_size = original_shape[-1] % chunk_size
|
||||
|
||||
# preallocate the output tensor
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# process full chunks
|
||||
for i in range(num_full_chunks):
|
||||
start = i * chunk_size
|
||||
end = start + chunk_size
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
q[..., start:end],
|
||||
k[..., start:end],
|
||||
v[..., start:end],
|
||||
out[..., start:end],
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
|
@ -346,4 +368,30 @@ else:
|
|||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
# process the last chunk if it exists
|
||||
if last_chunk_size > 0:
|
||||
start = num_full_chunks * chunk_size
|
||||
flash_attn_cuda.fwd(
|
||||
q[..., start:],
|
||||
k[..., start:],
|
||||
v[..., start:],
|
||||
out[..., start:],
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
# make sure it's the original type
|
||||
if out.dtype != original_type:
|
||||
out = out.to(original_type)
|
||||
|
||||
return out
|
||||
|
|
|
@ -500,11 +500,12 @@ def get_model(
|
|||
if max_input_tokens is not None and max_input_tokens <= sliding_window:
|
||||
sliding_window = -1
|
||||
|
||||
if (
|
||||
(sliding_window is not None and sliding_window != -1)
|
||||
and not SUPPORTS_WINDOWING
|
||||
and max_input_tokens > sliding_window
|
||||
):
|
||||
should_use_sliding_window = (
|
||||
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
|
||||
)
|
||||
|
||||
if should_use_sliding_window:
|
||||
if max_input_tokens is not None and max_input_tokens > sliding_window:
|
||||
raise ValueError(
|
||||
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue