diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index d039e1e7..b9205bd9 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -293,6 +293,7 @@ else:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=None,
         softcap=None,
     ):
         if window_size_left != -1:
@@ -328,22 +329,69 @@ else:
                     .reshape(original_shape[0], -1, original_shape[2])
                 )
 
+        original_shape = q.shape
+        original_type = q.dtype
+
+        # ensure the dtype is float16 and the last dimension is at most 128,
+        # as required by the flash attention kernel [flash-attention/csrc/flash_attn/fmha_api.cpp:246]
+        chunk_size = 128
+        if q.dtype == torch.bfloat16:
+            q = q.to(torch.float16)
+            k = k.to(torch.float16)
+            v = v.to(torch.float16)
+
+        # calculate the number of full chunks and the size of the last chunk
+        num_full_chunks = original_shape[-1] // chunk_size
+        last_chunk_size = original_shape[-1] % chunk_size
+
+        # preallocate the output tensor
         out = torch.empty_like(q)
-        flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            0,
-            None,
-        )
+
+        # process full chunks
+        for i in range(num_full_chunks):
+            start = i * chunk_size
+            end = start + chunk_size
+            flash_attn_cuda.fwd(
+                q[..., start:end],
+                k[..., start:end],
+                v[..., start:end],
+                out[..., start:end],
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                softmax_scale,
+                False,
+                True,
+                False,
+                0,
+                None,
+            )
+
+        # process the last chunk if it exists
+        if last_chunk_size > 0:
+            start = num_full_chunks * chunk_size
+            flash_attn_cuda.fwd(
+                q[..., start:],
+                k[..., start:],
+                v[..., start:],
+                out[..., start:],
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                softmax_scale,
+                False,
+                True,
+                False,
+                0,
+                None,
+            )
+
+        # cast back to the original dtype if it was converted
+        if out.dtype != original_type:
+            out = out.to(original_type)
+
         return out
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 960b426b..3ed83891 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -500,14 +500,15 @@ def get_model(
     if max_input_tokens is not None and max_input_tokens <= sliding_window:
         sliding_window = -1
 
-    if (
-        (sliding_window is not None and sliding_window != -1)
-        and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
-    ):
-        raise ValueError(
-            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
-        )
+    needs_unsupported_sliding_window = (
+        sliding_window is not None and sliding_window != -1 and not SUPPORTS_WINDOWING
+    )
+
+    if needs_unsupported_sliding_window:
+        if max_input_tokens is not None and max_input_tokens > sliding_window:
+            raise ValueError(
+                f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
+            )
 
     if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:
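
Notes on the cuda.py hunk: the slicing and bookkeeping of the head-dimension chunking can be exercised in isolation. Below is a minimal sketch that uses a placeholder `dummy_kernel` in place of `flash_attn_cuda.fwd`; it only demonstrates how the full 128-wide chunks and the trailing partial chunk are carved out of the last dimension and written back into the preallocated output, not the attention math performed by the real kernel. All names in the sketch are illustrative, not part of TGI.

    import torch

    CHUNK_SIZE = 128  # flash-attn v1 accepts a head dimension of at most 128


    def dummy_kernel(q, k, v, out):
        # stand-in for flash_attn_cuda.fwd: the real kernel writes its result
        # into `out` in place; here we just copy q so shapes can be checked
        out.copy_(q)


    def chunked_forward(q, k, v):
        out = torch.empty_like(q)
        head_dim = q.shape[-1]
        num_full_chunks = head_dim // CHUNK_SIZE
        last_chunk_size = head_dim % CHUNK_SIZE

        # full 128-wide chunks
        for i in range(num_full_chunks):
            start, end = i * CHUNK_SIZE, (i + 1) * CHUNK_SIZE
            dummy_kernel(
                q[..., start:end], k[..., start:end], v[..., start:end], out[..., start:end]
            )

        # trailing chunk narrower than 128, if any
        if last_chunk_size > 0:
            start = num_full_chunks * CHUNK_SIZE
            dummy_kernel(q[..., start:], k[..., start:], v[..., start:], out[..., start:])

        return out


    q = k = v = torch.randn(4, 8, 192)  # head dim 192 -> one full chunk plus a 64-wide tail
    assert chunked_forward(q, k, v).shape == q.shape

The sketch preallocates the output once and writes each chunk into a view of it, mirroring the diff's approach of reusing `out` rather than concatenating per-chunk results.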
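Notes on the models/__init__.py hunk: the refactored guard can be read as a standalone helper. The sketch below is illustrative only, with the module globals SUPPORTS_WINDOWING and SYSTEM replaced by parameters; the helper name, signature, and example values ("ipex", "mistral") are assumptions, not part of TGI.

    def check_sliding_window_support(
        sliding_window, max_input_tokens, supports_windowing, system, model_type
    ):
        # the model wants sliding-window attention but the backend cannot provide it
        needs_unsupported_sliding_window = (
            sliding_window is not None and sliding_window != -1 and not supports_windowing
        )
        if needs_unsupported_sliding_window:
            # only fatal when requests could actually exceed the window
            if max_input_tokens is not None and max_input_tokens > sliding_window:
                raise ValueError(
                    f"The backend {system} does not support sliding window attention used by "
                    f"{model_type}; launch with `--max-input-tokens` <= {sliding_window}."
                )


    try:
        check_sliding_window_support(4096, 8192, False, "ipex", "mistral")
    except ValueError as err:
        print(err)  # the guard fires: window needed, backend lacks support, inputs exceed it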