From 7830de1566df365e6cb9ce0a955e8e2ac1b28ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 11:42:00 +0200 Subject: [PATCH] Add FlashInfer support (#2354) This change adds support for FlashInfer. FlashInfer can be enabled using `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`. Since this functionality is currently only for testing, FlashInfer is not installed anywhere yet. The FlashInfer API is quite different from FlashAttention/vLLM in that it requires more global bookkeeping: * A wrapper class needs to be contstructed (which we just call *state*). Since this is fairly expensive (due to pinned host memory allocation), we only do this once in a FlashCausalLM instance or for each CUDA Graph size. * Each model forward call needs to be wrapped in `begin_forward` and `end_forward`. This sets up data structures that can be reused for all calls to attention for that forward call. When calling attention, we need access to the state object. To avoid passing an argument down the call chain (which would require changes to all models), we use a context variable. Each model forward call is wrapped using a context manager that does all the bookkeeping for such a call: * Set the context variable to the forward call's state. * Call `begin_forward` on the state. * Yield. * Call `end_forward` on the state. * Reset the context variable. We cannot use a single shared global variable for this, since e.g. CUDA Graphs of different sizes each have their own state. --- .../layers/attention/common.py | 4 +- .../layers/attention/cuda.py | 46 ++++- .../layers/attention/flash_infer.py | 164 +++++++++++++++++ .../models/flash_causal_lm.py | 171 ++++++++++++++---- .../text_generation_server/models/globals.py | 5 + 5 files changed, 346 insertions(+), 44 deletions(-) create mode 100644 server/text_generation_server/layers/attention/flash_infer.py diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index bd0717ce..b986a082 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER import torch from typing import Optional -if FLASH_DECODING: +if FLASH_DECODING or FLASH_INFER: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 96b654d0..1b8e9209 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,6 +1,10 @@ import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.models.globals import ( + FLASH_DECODING, + BLOCK_SIZE, + FLASH_INFER, +) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -23,7 +27,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -72,7 +76,16 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_DECODING: + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import decode_state + + return decode_state.get().forward( + query.contiguous(), + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + elif FLASH_DECODING: max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -206,7 +219,32 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if V2: +if FLASH_INFER: + + def attention( + q, + k, + v, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + from text_generation_server.layers.attention.flash_infer import prefill_state + + return prefill_state.get().forward( + q, + k, + v, + causal=causal, + window_left=window_size_left, + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + +elif V2: def attention( q, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flash_infer.py new file mode 100644 index 00000000..56b53b2c --- /dev/null +++ b/server/text_generation_server/layers/attention/flash_infer.py @@ -0,0 +1,164 @@ +from typing import Optional +from contextvars import ContextVar +from contextlib import contextmanager + +import flashinfer +import torch + +prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( + "prefill_state" +) + +decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( + "decode_state" +) + +workspace: Optional[torch.Tensor] = None + + +def get_workspace(device): + """Get shared flashinfer workspace.""" + global workspace + if workspace is None: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + return workspace + + +def create_prefill_state( + *, + device: torch.device, +): + """Create a prefill state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_state( + *, + state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, + cu_seqlens: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + token = prefill_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + kv_indptr=cu_seqlens, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_state.reset(token) + + +def create_decode_state( + *, + device: torch.device, + num_heads: int, + num_kv_heads: int, +): + """Create a decode state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=False, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +def create_decode_state_cuda_graphs( + *, + device: torch.device, + block_tables: torch.Tensor, + block_tables_ptr: torch.Tensor, + last_page_len: torch.Tensor, + num_heads: int, + num_kv_heads: int, +): + """ + Create a decode state for use with CUDA Graphs. `block_tables`, + `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are + therefore stored as part of the state. + """ + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=True, + paged_kv_indices_buffer=block_tables, + paged_kv_indptr_buffer=block_tables_ptr, + paged_kv_last_page_len_buffer=last_page_len, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +@contextmanager +def use_decode_state( + *, + state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, + input_lengths: torch.Tensor, + block_tables: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer decoding state to the given + `state` and parameters. This state will be used by all calls to the + `paged_attention` function while the context manager is active. + """ + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = decode_state.set(state) + + try: + state.begin_forward( + indptr=indptr, + indices=block_tables, + last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=page_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + decode_state.reset(token) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36bb2662..12aa7dcd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext import math import os import time @@ -15,7 +16,7 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Iterable, Optional, Tuple, List, Type, Dict +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, FLASH_DECODING, + FLASH_INFER, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -907,6 +909,7 @@ class FlashCausalLM(Model): config.sliding_window = None self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -935,6 +938,21 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_prefill_state, + create_decode_state, + ) + + self.prefill_state = create_prefill_state(device=device) + + if not CUDA_GRAPHS: + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + super().__init__( model_id=model_id, model=model, @@ -972,7 +990,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: self.kv_cache = [ ( torch.empty( @@ -1044,38 +1062,66 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables.view(-1), + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + self.cuda_graphs[bs]["state"] = state + else: + state = None + torch.cuda.synchronize() # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths_, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -1295,23 +1341,28 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, + cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1325,8 +1376,16 @@ class FlashCausalLM(Model): cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - # Replay the graph - cuda_graph["graph"].replay() + state = cuda_graph.get("state") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + # Replay the graph + cuda_graph["graph"].replay() + # Slice output to the correct shape speculative_logits = ( cuda_graph["speculative_logits"][:bs] @@ -1698,3 +1757,39 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + def _forward_context( + self, + *, + block_tables: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + input_lengths: torch.Tensor, + state: Optional[Any] = None, + ) -> ContextManager: + if not FLASH_INFER: + return nullcontext() + + from text_generation_server.layers.attention.flash_infer import ( + use_decode_state, + use_prefill_state, + ) + + if cu_seqlen_prefill is not None: + return use_prefill_state( + state=state if state is not None else self.prefill_state, + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + ) + else: + assert input_lengths is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths, + block_tables=block_tables.view(-1), + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8d2431db..42b43c87 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,10 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} +if FLASH_INFER: + log_master(logger.info, "Using FLASH_INFER") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} @@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: log_master(logger.info, "Using FLASH_DECODING") + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: