Add FlashInfer support (#2354)

This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:

* A wrapper class needs to be contstructed (which we just call *state*).
  Since this is fairly expensive (due to pinned host memory allocation),
  we only do this once in a FlashCausalLM instance or for each CUDA
  Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
  `end_forward`. This sets up data structures that can be reused for all
  calls to attention for that forward call.

When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.

Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since e.g. CUDA
Graphs of different sizes each have their own state.
This commit is contained in:
Daniël de Kok 2024-08-09 11:42:00 +02:00 committed by GitHub
parent 6d06473cf4
commit 7830de1566
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 346 additions and 44 deletions

View File

@ -1,10 +1,10 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
import torch
from typing import Optional
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
@dataclass
class Seqlen:

View File

@ -1,6 +1,10 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.models.globals import (
FLASH_DECODING,
BLOCK_SIZE,
FLASH_INFER,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional
@ -23,7 +27,7 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
@ -72,7 +76,16 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if FLASH_DECODING:
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import decode_state
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
@ -206,7 +219,32 @@ except ImportError:
SUPPORTS_WINDOWING = V2
if V2:
if FLASH_INFER:
def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flash_infer import prefill_state
return prefill_state.get().forward(
q,
k,
v,
causal=causal,
window_left=window_size_left,
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif V2:
def attention(
q,

View File

@ -0,0 +1,164 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
"prefill_state"
)
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
workspace: Optional[torch.Tensor] = None
def get_workspace(device):
"""Get shared flashinfer workspace."""
global workspace
if workspace is None:
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
return workspace
def create_prefill_state(
*,
device: torch.device,
):
"""Create a prefill state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state(
*,
device: torch.device,
num_heads: int,
num_kv_heads: int,
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
)
def create_decode_state_cuda_graphs(
*,
device: torch.device,
block_tables: torch.Tensor,
block_tables_ptr: torch.Tensor,
last_page_len: torch.Tensor,
num_heads: int,
num_kv_heads: int,
):
"""
Create a decode state for use with CUDA Graphs. `block_tables`,
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
)
@contextmanager
def use_decode_state(
*,
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
input_lengths: torch.Tensor,
block_tables: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer decoding state to the given
`state` and parameters. This state will be used by all calls to the
`paged_attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = decode_state.set(state)
try:
state.begin_forward(
indptr=indptr,
indices=block_tables,
last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
q_data_type=query_dtype,
)
yield
finally:
state.end_forward()
if token is not None:
decode_state.reset(token)

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
import math
import os
import time
@ -15,7 +16,7 @@ from transformers import (
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
MEM_POOL,
FLASH_DECODING,
FLASH_INFER,
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
@ -907,6 +909,7 @@ class FlashCausalLM(Model):
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
@ -935,6 +938,21 @@ class FlashCausalLM(Model):
self.cuda_graphs = {}
self.kv_cache = []
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import (
create_prefill_state,
create_decode_state,
)
self.prefill_state = create_prefill_state(device=device)
if not CUDA_GRAPHS:
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
super().__init__(
model_id=model_id,
model=model,
@ -972,7 +990,7 @@ class FlashCausalLM(Model):
else:
x = BLOCK_SIZE // element_size
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
self.kv_cache = [
(
torch.empty(
@ -1044,38 +1062,66 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables.view(-1),
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.cuda_graphs[bs]["state"] = state
else:
state = None
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
with self._forward_context(
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths_,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
cu_seqlen_prefill=None,
input_lengths=input_lengths,
state=state,
):
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
input_lengths=input_lengths_,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def warmup(self, batch: FlashCausalLMBatch):
@ -1295,23 +1341,28 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
with self._forward_context(
block_tables=block_tables,
slots=slots,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
):
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
@ -1325,8 +1376,16 @@ class FlashCausalLM(Model):
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
state = cuda_graph.get("state")
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths=input_lengths,
state=state,
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
@ -1698,3 +1757,39 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
def _forward_context(
self,
*,
block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: torch.Tensor,
state: Optional[Any] = None,
) -> ContextManager:
if not FLASH_INFER:
return nullcontext()
from text_generation_server.layers.attention.flash_infer import (
use_decode_state,
use_prefill_state,
)
if cu_seqlen_prefill is not None:
return use_prefill_state(
state=state if state is not None else self.prefill_state,
cu_seqlens=cu_seqlen_prefill,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
)
else:
assert input_lengths is not None
return use_decode_state(
state=state if state is not None else self.decode_state,
input_lengths=input_lengths,
block_tables=block_tables.view(-1),
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
)

View File

@ -5,6 +5,10 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
if FLASH_INFER:
log_master(logger.info, "Using FLASH_INFER")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
log_master(logger.info, "Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
try: