Fix TunableOp bug (#1920)

cc @Narsil
This commit is contained in:
fxmarty 2024-05-17 18:21:51 +02:00 committed by GitHub
parent 422bf1f986
commit b5f1c9de06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 1 deletions

View File

@ -391,6 +391,28 @@ class BaseFlashMistral(FlashCausalLM):
def batch_type(self) -> Type[FlashMistralBatch]:
return FlashMistralBatch
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=None,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
)
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)

View File

@ -522,6 +522,30 @@ class Mamba(Model):
}
self.cuda_graphs[batch_size] = graph_dict
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# Inner takes the expand multiplication
d_inner = self.model.config.d_inner
# Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=seqlen,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
self.model.forward(input_ids=input_ids, inference_params=inference_params)
def forward(
self, input_ids: torch.Tensor, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]: