from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    """Per-batch sequence-length metadata handed to the attention kernels.

    input_lengths: number of new (query) tokens per sequence.
    cache_lengths: number of tokens already held in the KV cache per sequence.
    cu_seqlen_q / cu_seqlen_k: cumulative query/key lengths (int32, leading zero).
    max_q / max_k: maximum query/key length across the batch.
    """

    input_lengths: torch.Tensor
    cache_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]
    max_q: int
    max_k: int

    def __init__(
        self,
        input_lengths,
        cache_lengths,
        cu_seqlen_q=None,
        max_q=None,
        max_k=None,
    ):
        self.input_lengths = input_lengths
        self.cache_lengths = cache_lengths
        device = self.input_lengths.device
        shape = self.input_lengths.shape
        if cu_seqlen_q is None:
            # Default (decode) case: each sequence contributes a single query
            # token, so the cumulative query lengths are simply 0..batch_size
            # and max_q is 1.
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
            max_q = 1
        else:
            assert max_q is not None
            assert max_k is not None
        # Cumulative key lengths: leading zero followed by the running sum of
        # (new tokens + cached tokens) per sequence.
        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
        # CUDA graphs don't like this, and the clamping is necessary within
        # Mistral, although FA2 might not want it.
        # cu_seqlen_k[0] = 0
        total = self.input_lengths + self.cache_lengths
        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
        self.cu_seqlen_q = cu_seqlen_q
        self.cu_seqlen_k = cu_seqlen_k
        self.max_q = max_q
        self.max_k = max_k

    def clamp(self, max):
        # Flash decoding doesn't need to clamp.
        return self
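

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of how Seqlen builds its cumulative-length tensors
# on the default (decode) path, i.e. with cu_seqlen_q left as None. The concrete
# batch sizes and lengths below are assumptions chosen for illustration only.
if __name__ == "__main__":
    # Two sequences, each producing one new query token on top of KV caches
    # that already hold 3 and 7 tokens respectively.
    input_lengths = torch.tensor([1, 1], dtype=torch.int32)
    cache_lengths = torch.tensor([3, 7], dtype=torch.int32)

    seqlen = Seqlen(
        input_lengths=input_lengths,
        cache_lengths=cache_lengths,
        max_k=8,
    )

    print(seqlen.cu_seqlen_q)  # tensor([0, 1, 2], dtype=torch.int32)
    print(seqlen.cu_seqlen_k)  # tensor([0, 4, 12], dtype=torch.int32)
    print(seqlen.max_q, seqlen.max_k)  # 1 8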