parent
422bf1f986
commit
b5f1c9de06
|
@ -391,6 +391,28 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
|
||||
def tunableop_warmup(self, seqlen: int):
|
||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
),
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=None,
|
||||
input_lengths=None,
|
||||
slots=slots,
|
||||
max_s=seqlen,
|
||||
lm_head_indices=None,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
|
|
|
@ -522,6 +522,30 @@ class Mamba(Model):
|
|||
}
|
||||
self.cuda_graphs[batch_size] = graph_dict
|
||||
|
||||
def tunableop_warmup(self, seqlen: int):
|
||||
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
|
||||
n_blocks = len(self.model.blocks)
|
||||
|
||||
d_state = self.model.config.d_state
|
||||
d_conv = self.model.config.d_conv
|
||||
# Inner takes the expand multiplication
|
||||
d_inner = self.model.config.d_inner
|
||||
|
||||
# Important seqlen_offset to go through the update mecanism with the state
|
||||
seqlen_offset = 1
|
||||
inference_params = new_inference_params(
|
||||
n_blocks=n_blocks,
|
||||
batch_size=seqlen,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
d_inner=d_inner,
|
||||
seqlen_offset=seqlen_offset,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.model.forward(input_ids=input_ids, inference_params=inference_params)
|
||||
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, inference_params: Any
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
|
Loading…
Reference in New Issue