diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index 5f8954ea..d603c6f5 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer prefill state to the given
@@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=query_dtype,
+            q_data_type=dtype,
             page_size=page_size,
+            window_left=window_left,
         )
         yield
     finally:
@@ -119,7 +121,8 @@ def use_prefill_state(
     num_heads: int,
     num_kv_heads: int,
     head_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer prefill state to the given
@@ -135,7 +138,8 @@ def use_prefill_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=query_dtype,
+            q_data_type=dtype,
+            window_left=window_left,
         )
         yield
     finally:
@@ -200,7 +204,8 @@ def use_decode_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer decoding state to the given
@@ -235,7 +240,9 @@ def use_decode_state(
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
             page_size=page_size,
-            q_data_type=query_dtype,
+            data_type=dtype,
+            q_data_type=dtype,
+            window_left=window_left,
         )
         yield
     finally:
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index a2834962..57582ebc 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1960,6 +1960,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                dtype=self.dtype,
+                window_left=self.sliding_window,
             )
         else:
             assert input_lengths_tensor is not None
@@ -1971,6 +1973,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                dtype=self.dtype,
+                window_left=self.sliding_window,
             )
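
Note on the new parameters (a sketch, not part of the patch): after this change the flashinfer context managers take a torch.dtype plus an explicit window_left instead of the old query_dtype string, and forward both to the wrapper's begin_forward. The sketch below mirrors the decode call site from the flash_causal_lm.py hunk; arguments not visible in the hunks (state, input_lengths, block_tables, num_heads) and the helper/model names are assumptions, not part of the diff.

    from text_generation_server.layers.attention.flashinfer import use_decode_state

    def decode_attention_context(model, block_tables, input_lengths_tensor, page_size):
        # Hypothetical helper: the model is assumed to expose decode_state,
        # num_heads, num_kv_heads, head_size, dtype and sliding_window,
        # matching the attributes used in the flash_causal_lm.py hunk.
        return use_decode_state(
            state=model.decode_state,
            input_lengths=input_lengths_tensor,
            block_tables=block_tables,
            num_heads=model.num_heads,
            num_kv_heads=model.num_kv_heads,
            head_size=model.head_size,
            page_size=page_size,
            dtype=model.dtype,                 # replaces query_dtype: str = "float16"
            window_left=model.sliding_window,  # new: sliding-window size forwarded to flashinfer
        )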