flashinfer: pass window size and dtype (#2574)
This commit is contained in:
parent
5b6b74e21d
commit
1028996fb3
|
@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
|
|||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
|
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
|
|||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=query_dtype,
|
||||
q_data_type=dtype,
|
||||
page_size=page_size,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
|
@ -119,7 +121,8 @@ def use_prefill_state(
|
|||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
|
@ -135,7 +138,8 @@ def use_prefill_state(
|
|||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=query_dtype,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
|
@ -200,7 +204,8 @@ def use_decode_state(
|
|||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
query_dtype: str = "float16",
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer decoding state to the given
|
||||
|
@ -235,7 +240,9 @@ def use_decode_state(
|
|||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
page_size=page_size,
|
||||
q_data_type=query_dtype,
|
||||
data_type=dtype,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
|
|
|
@ -1960,6 +1960,8 @@ class FlashCausalLM(Model):
|
|||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
page_size=BLOCK_SIZE,
|
||||
dtype=self.dtype,
|
||||
window_left=self.sliding_window,
|
||||
)
|
||||
else:
|
||||
assert input_lengths_tensor is not None
|
||||
|
@ -1971,6 +1973,8 @@ class FlashCausalLM(Model):
|
|||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
page_size=BLOCK_SIZE,
|
||||
dtype=self.dtype,
|
||||
window_left=self.sliding_window,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue