wip: debug gemma and flash
This commit is contained in:
parent
8dcc7d3f6b
commit
7bc16deb48
|
@ -293,6 +293,7 @@ else:
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
|
causal=None,
|
||||||
softcap=None,
|
softcap=None,
|
||||||
):
|
):
|
||||||
if window_size_left != -1:
|
if window_size_left != -1:
|
||||||
|
@ -328,22 +329,69 @@ else:
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
original_shape = q.shape
|
||||||
|
original_type = q.dtype
|
||||||
|
|
||||||
|
# ensure type is a float16 and that the last dimension is 128 or less
|
||||||
|
# as required by the flash attention kernel [flash-attention/csrc/flash_attn/fmha_api.cpp:246]
|
||||||
|
chunk_size = 128
|
||||||
|
if q.dtype == torch.bfloat16:
|
||||||
|
q = q.to(torch.float16)
|
||||||
|
k = k.to(torch.float16)
|
||||||
|
v = v.to(torch.float16)
|
||||||
|
|
||||||
|
# calculate the number of full chunks and the size of the last chunk
|
||||||
|
num_full_chunks = original_shape[-1] // chunk_size
|
||||||
|
last_chunk_size = original_shape[-1] % chunk_size
|
||||||
|
|
||||||
|
# preallocate the output tensor
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(q)
|
||||||
flash_attn_cuda.fwd(
|
|
||||||
q,
|
# process full chunks
|
||||||
k,
|
for i in range(num_full_chunks):
|
||||||
v,
|
start = i * chunk_size
|
||||||
out,
|
end = start + chunk_size
|
||||||
cu_seqlens,
|
flash_attn_cuda.fwd(
|
||||||
cu_seqlens,
|
q[..., start:end],
|
||||||
max_s,
|
k[..., start:end],
|
||||||
max_s,
|
v[..., start:end],
|
||||||
0.0,
|
out[..., start:end],
|
||||||
softmax_scale,
|
cu_seqlens,
|
||||||
False,
|
cu_seqlens,
|
||||||
True,
|
max_s,
|
||||||
False,
|
max_s,
|
||||||
0,
|
0.0,
|
||||||
None,
|
softmax_scale,
|
||||||
)
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# process the last chunk if it exists
|
||||||
|
if last_chunk_size > 0:
|
||||||
|
start = num_full_chunks * chunk_size
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
|
q[..., start:],
|
||||||
|
k[..., start:],
|
||||||
|
v[..., start:],
|
||||||
|
out[..., start:],
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure it's the original type
|
||||||
|
if out.dtype != original_type:
|
||||||
|
out = out.to(original_type)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
|
@ -500,14 +500,15 @@ def get_model(
|
||||||
if max_input_tokens is not None and max_input_tokens <= sliding_window:
|
if max_input_tokens is not None and max_input_tokens <= sliding_window:
|
||||||
sliding_window = -1
|
sliding_window = -1
|
||||||
|
|
||||||
if (
|
should_use_sliding_window = (
|
||||||
(sliding_window is not None and sliding_window != -1)
|
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
|
||||||
and not SUPPORTS_WINDOWING
|
)
|
||||||
and max_input_tokens > sliding_window
|
|
||||||
):
|
if should_use_sliding_window:
|
||||||
raise ValueError(
|
if max_input_tokens is not None and max_input_tokens > sliding_window:
|
||||||
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
raise ValueError(
|
||||||
)
|
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == DEEPSEEK_V2:
|
if model_type == DEEPSEEK_V2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
|
Loading…
Reference in New Issue