From 1028996fb380f07ebb2a9de1d2795e176f845c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Sat, 28 Sep 2024 18:41:41 +0200 Subject: [PATCH] flashinfer: pass window size and dtype (#2574) --- .../layers/attention/flashinfer.py | 19 +++++++++++++------ .../models/flash_causal_lm.py | 4 ++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 5f8954ea..d603c6f5 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state( num_kv_heads: int, head_size: int, page_size: int, - query_dtype: str = "float16", + dtype: torch.dtype, + window_left: int, ): """ Context manager to set the active flashinfer prefill state to the given @@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state( num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, - q_data_type=query_dtype, + q_data_type=dtype, page_size=page_size, + window_left=window_left, ) yield finally: @@ -119,7 +121,8 @@ def use_prefill_state( num_heads: int, num_kv_heads: int, head_size: int, - query_dtype: str = "float16", + dtype: torch.dtype, + window_left: int, ): """ Context manager to set the active flashinfer prefill state to the given @@ -135,7 +138,8 @@ def use_prefill_state( num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, - q_data_type=query_dtype, + q_data_type=dtype, + window_left=window_left, ) yield finally: @@ -200,7 +204,8 @@ def use_decode_state( num_kv_heads: int, head_size: int, page_size: int, - query_dtype: str = "float16", + dtype: torch.dtype, + window_left: int, ): """ Context manager to set the active flashinfer decoding state to the given @@ -235,7 +240,9 @@ def use_decode_state( num_kv_heads=num_kv_heads, head_dim=head_size, page_size=page_size, - q_data_type=query_dtype, + data_type=dtype, + q_data_type=dtype, + window_left=window_left, ) yield finally: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a2834962..57582ebc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1960,6 +1960,8 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, ) else: assert input_lengths_tensor is not None @@ -1971,6 +1973,8 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, )