From 9b3674d9038a62c42534e1b3ee2d56257dd214ff Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 10 Jun 2024 09:09:50 +0200
Subject: [PATCH 1/2] ROCm and sliding windows fixes (#2033)

* update vllm commit & fix models using sliding window

* update

* update commit

* fix bug where tunableop is bound to cuda graph even when cuda graphs are disabled

* enable tunableop by default

* fix sliding window

* address review

* dead code

* precise comment

* is it flaky?
---
 launcher/src/main.rs                          |  8 +++++++
 server/Makefile-vllm                          |  4 ++--
 server/text_generation_server/cli.py          |  2 ++
 .../layers/attention/rocm.py                  | 11 +++-------
 .../layers/attention/xpu.py                   |  5 +----
 .../text_generation_server/models/__init__.py | 21 +++++++++++--------
 .../models/flash_causal_lm.py                 |  7 ++++++-
 server/text_generation_server/server.py       |  2 ++
 8 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index c40a8461..e4d5bb85 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -481,6 +481,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }
 
+    // In case the model uses a sliding window, some backends may ignore the window in flash attention depending on this parameter.
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -1009,6 +1014,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
                 rope_factor,
                 max_total_tokens,
                 max_batch_size,
+                max_input_tokens,
                 otlp_endpoint,
                 max_log_level,
                 status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index ded2f5d2..8c0437ea 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,5 +1,5 @@
 commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -19,5 +19,5 @@ build-vllm-rocm:
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 install-vllm-rocm: build-vllm-rocm
-	cd vllm && git fetch && git checkout $(commit_rocm) && \
+	cd vllm && git fetch && git checkout $(commit_rocm) && \
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 68b429d0..430323bc 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -42,6 +42,7 @@ def serve(
     logger_level: str = "INFO",
    json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,6 +99,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
+        max_input_tokens,
     )
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 535810aa..91ed5818 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index d9a096f9..8b6cb87b 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index ba353c11..a61cb83b 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.phi import Phi
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
             "Sharding is currently not supported with `exl2` quantization"
         )
     sliding_window = config_dict.get("sliding_window", -1)
-    if sliding_window != -1 and not SUPPORTS_WINDOWING:
-        logger.warning(
-            f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
+
+    if (
+        (sliding_window is not None and sliding_window != -1)
+        and not SUPPORTS_WINDOWING
+        and max_input_tokens > sliding_window
+    ):
+        raise ValueError(
+            f"The backend {SYSTEM} does not support sliding window attention, which is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got max_input_tokens={max_input_tokens})."
         )
-        FLASH_ATTENTION = False
 
     if model_type == MAMBA:
         return Mamba(
@@ -701,7 +708,6 @@ def get_model(
         )
 
     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
@@ -724,7 +730,6 @@ def get_model(
         )
 
     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
@@ -747,7 +752,6 @@ def get_model(
         )
 
     if model_type == STARCODER2:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
@@ -771,8 +775,7 @@ def get_model(
         )
 
     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index acf77b09..d16d3710 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -902,6 +902,8 @@ class FlashCausalLM(Model):
             os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
             or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
         ):
+            torch.cuda.tunable.enable()
+
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                 torch.cuda.tunable.tuning_enable(True)
 
@@ -910,8 +912,11 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            else:
+            elif CUDA_GRAPHS is not None:
                 tuning_sequences = CUDA_GRAPHS
+            else:
+                # For seqlen = 1, we dispatch to LLMM1 kernel.
+                tuning_sequences = [2, 3, 4, 5, 6, 7]
 
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 4118b3f6..569b6925 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -199,6 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -229,6 +230,7 @@ def serve(
             speculate,
             dtype,
             trust_remote_code,
+            max_input_tokens,
         )
     except Exception:
         logger.exception("Error when initializing model")

From 85dfc39222798b75559c891789283de23c679ca5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 10 Jun 2024 09:22:29 +0200
Subject: [PATCH 2/2] Add Phi-3 medium support (#2039)

Add support for Phi-3-medium

The main difference between the medium and mini models is that medium
uses grouped query attention with a packed QKV matrix. This change adds
support for GQA with packed matrices to
`Weights.get_weights_col_packed` and uses it for Phi-3. This also
allows us to remove the custom implementation of GQA from dbrx
attention loading.
---
 .../layers/tensor_parallel.py                 |  17 ++-
 .../custom_modeling/flash_dbrx_modeling.py    | 131 +-----------------
 .../custom_modeling/flash_gpt2_modeling.py    |   7 +-
 .../custom_modeling/flash_llama_modeling.py   |   9 ++
 .../text_generation_server/utils/weights.py   | 118 +++++++++++-----
 5 files changed, 118 insertions(+), 164 deletions(-)

diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 192c2b42..6005f737 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)
 
     @classmethod
-    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+    def load_qkv(
+        cls,
+        config,
+        prefix: str,
+        weights,
+        bias: bool,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
         """Specific method when the QKV was joined after the fact"""
-        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
+        weight = weights.get_weights_col_packed_qkv(
+            prefix,
+            quantize=config.quantize,
+            num_heads=num_heads,
+            num_key_value_heads=num_key_value_heads,
+        )
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 63ce6543..94cf6452 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -20,7 +20,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
 
 
 def load_attention(config, prefix, weights):
-    if config.n_heads != config.attn_config.kv_n_heads:
-        return _load_gqa(config, prefix, weights)
-    else:
-        return TensorParallelColumnLinear.load_qkv(
-            config,
-            prefix=f"{prefix}.Wqkv",
-            weights=weights,
-            bias=False,
-        )
-
-
-def _load_gqa(config, prefix: str, weights):
-    assert config.d_model % config.n_heads == 0
-    assert config.n_heads % weights.process_group.size() == 0
-
-    head_dim = config.d_model // config.n_heads
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    q_block_size = config.d_model // world_size
-    q_start = rank * q_block_size
-    q_stop = (rank + 1) * q_block_size
-
-    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
-    k_offset = config.d_model
-    k_start = k_offset + rank * kv_block_size
-    k_stop = k_offset + (rank + 1) * kv_block_size
-
-    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
-    v_start = v_offset + rank * kv_block_size
-    v_stop = v_offset + (rank + 1) * kv_block_size
-
-    if config.quantize in ["gptq", "awq"]:
-        from text_generation_server.layers.gptq import GPTQWeight
-
-        try:
-            qweight_slice = weights._get_slice(f"{prefix}.qweight")
-            q_qweight = qweight_slice[:, q_start:q_stop]
-            k_qweight = qweight_slice[:, k_start:k_stop]
-            v_qweight = qweight_slice[:, v_start:v_stop]
-
-            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
-        except RuntimeError:
-            raise RuntimeError(
-                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
-            )
-
-        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
-        q_qzeros = qzeros_slice[:, q_start:q_stop]
-        k_qzeros = qzeros_slice[:, k_start:k_stop]
-        v_qzeros = qzeros_slice[:, v_start:v_stop]
-
-        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
-
-        scales_slice = weights._get_slice(f"{prefix}.scales")
-        q_scales = scales_slice[:, q_start:q_stop]
-        k_scales = scales_slice[:, k_start:k_stop]
-        v_scales = scales_slice[:, v_start:v_stop]
-
-        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
-
-        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
-
-        from text_generation_server.layers import HAS_EXLLAMA
-
-        use_exllama = (
-            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
-        )
-
-        if config.quantize == "gptq" and quant_method == "gptq":
-            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
-            q_g_idx = g_idx_slice[:, q_start:q_stop]
-            k_g_idx = g_idx_slice[:, k_start:k_stop]
-            v_g_idx = g_idx_slice[:, v_start:v_stop]
-
-            w = [q_g_idx, k_g_idx, v_g_idx]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
-        elif config.quantize == "gptq" and quant_method == "awq":
-            log_once(
-                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
-            )
-            from text_generation_server.layers.awq.conveersion_utils import (
-                fast_awq_to_gptq,
-            )
-
-            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-            if use_exllama:
-                g_idx = None
-            else:
-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
-        else:
-            g_idx = None
-
-        weight = GPTQWeight(
-            qweight=qweight,
-            qzeros=qzeros,
-            scales=scales,
-            g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
-            use_exllama=use_exllama,
-        )
-    elif config.quantize == "marlin":
-        # NOTE: at the time marlin support was added, the only model that
-        # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
-        # but it requires manual concatenation of weight files.
- raise RuntimeError("dbrx models with marlin quantization are not yet supported") - else: - qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") - q = qkv_slice[q_start:q_stop] - k = qkv_slice[k_start:k_stop] - v = qkv_slice[v_start:v_stop] - - weight = torch.cat([q, k, v], dim=0) - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.Wqkv", + weights=weights, + bias=False, + num_heads=config.n_heads, + num_key_value_heads=config.attn_config.kv_n_heads, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 0178c911..0c01f56a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights): rank = weights.process_group.rank() # Weights - weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize) + weight = weights.get_weights_col_packed_qkv( + f"{prefix}.c_attn", + config.quantize, + config.num_attention_heads, + config.num_attention_heads, + ) # Bias slice_ = weights._get_slice(f"{prefix}.c_attn.bias") diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index cef712f0..0d06d104 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -62,6 +62,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": return TensorParallelColumnLinear.load_qkv( @@ -69,6 +71,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.W_pack", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) # otherwise, load the default attention based on the number of heads @@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module): f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"and `num_shards`: {weights.process_group.size()}" ) + if config.num_key_value_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 557656e7..4d5fcb25 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass, field import os from pathlib import Path -from typing import List, Dict, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger @@ -121,49 +120,62 @@ 
             ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str, blocks: int):
+    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert (
-            total_size % blocks == 0
-        ), f"Prepacked quantized matrix is not divisible by {blocks}"
-        single_size = total_size // blocks
+        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
+
         world_size = self.process_group.size()
         rank = self.process_group.rank()
-        assert (
-            single_size % world_size == 0
-        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
-        block_size = single_size // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
+
         weights = []
-        for block in range(blocks):
-            weights.append(
-                slice_[:, start + block * single_size : stop + block * single_size]
-            )
+        block_offset = 0
+        for block_size in block_sizes:
+            assert (
+                block_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            shard_block_size = block_size // world_size
+            start = rank * shard_block_size
+            stop = (rank + 1) * shard_block_size
+            weights.append(slice_[:, block_offset + start : block_offset + stop])
+            block_offset += block_size
 
         weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight
 
-    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
-        return self.get_weights_col_packed(prefix, quantize, 3)
+    def get_weights_col_packed_qkv(
+        self,
+        prefix: str,
+        quantize: str,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
+        return self.get_weights_col_packed(
+            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
+        )
 
     def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
         return self.get_weights_col_packed(prefix, quantize, 2)
 
-    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
+    def get_weights_col_packed(
+        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
+    ):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
-        already alternating Q,K,V within the main tensor
+        already alternating Q,K,V within the main tensor.
+
+        The columns are split in equally sized blocks when `block_sizes` is an `int`, or
+        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
+        divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
+        convenient for e.g. splitting QKV without knowing the storage details of
+        quantized weights.
         """
         if quantize in ["gptq", "awq"]:
             from text_generation_server.layers.gptq import GPTQWeight
 
             try:
-                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
+                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@@ -171,8 +183,8 @@ class Weights:
 
             bits, groupsize, _, quant_method = self._get_gptq_params()
 
-            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
-            scales = self._get_qweight(f"{prefix}.scales", blocks)
+            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
+            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
             scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
@@ -205,27 +217,31 @@ class Weights:
         elif quantize == "marlin":
             from text_generation_server.layers.marlin import MarlinWeight
 
-            B = self._get_qweight(f"{prefix}.B", blocks)
-            s = self._get_qweight(f"{prefix}.s", blocks)
+            B = self._get_qweight(f"{prefix}.B", block_sizes)
+            s = self._get_qweight(f"{prefix}.s", block_sizes)
             weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
-            single_size = total_size // blocks
+            block_sizes = _blocks_to_block_sizes(
+                total_size=total_size, blocks=block_sizes
+            )
+
             world_size = self.process_group.size()
             rank = self.process_group.rank()
-            assert (
-                single_size % world_size == 0
-            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
-            block_size = single_size // world_size
-            start = rank * block_size
-            stop = (rank + 1) * block_size
 
             tensors = []
-            for i in range(blocks):
-                tensor = slice_[start + i * single_size : stop + i * single_size]
+            block_offset = 0
+            for block_size in block_sizes:
+                assert (
+                    block_size % world_size == 0
+                ), f"Prepacked weights cannot be sharded across {world_size} shards"
+                shard_block_size = block_size // world_size
+                start = rank * shard_block_size
+                stop = (rank + 1) * shard_block_size
+                tensor = slice_[block_offset + start : block_offset + stop]
                 tensors.append(tensor)
+                block_offset += block_size
             weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
@@ -593,3 +609,31 @@ class Weights:
                 self.quant_method = "awq"
         except Exception:
             pass
+
+
+def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
+    """
+    Convert block count or proportions to block sizes.
+
+    This function accepts
+
+    - The number of blocks (int), in which case the block size is
+      total_size//blocks; or
+    - A list of block sizes (List[int]).
+
+    In the latter case, if sum(blocks) < total_size, the ratios between
+    the block sizes will be preserved. For instance, if blocks is
+    [2, 1, 1] and total_size is 1024, the returned block sizes are
+    [512, 256, 256].
+    """
+    if isinstance(blocks, list):
+        total_blocks = sum(blocks)
+        assert (
+            total_size % total_blocks == 0
+        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
+        part_size = total_size // total_blocks
+        return [part_size * block for block in blocks]
+    else:
+        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+        single_size = total_size // blocks
+        return [single_size] * blocks
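
Two short sketches follow to illustrate the behavior these patches introduce. They are illustrative only and not part of the patches; the names mirror the diffs above, and the concrete numbers are hypothetical.

First, the sliding-window gate from PATCH 1/2: a model with a sliding window may still be served on a backend without windowing support, as long as prompts can never exceed the window. A minimal standalone sketch, with `SUPPORTS_WINDOWING` and `SYSTEM` standing in for the values resolved in the patched `models/__init__.py`:

```python
SUPPORTS_WINDOWING = False  # e.g. the ROCm and XPU attention backends patched above
SYSTEM = "rocm"


def check_sliding_window(config_dict: dict, model_type: str, max_input_tokens: int) -> None:
    # Mirrors the check added in get_model(): only fail when the model uses a
    # sliding window, the backend cannot window, AND inputs can exceed the window.
    sliding_window = config_dict.get("sliding_window", -1)
    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention, "
            f"which is used by the model type {model_type}. Please launch TGI "
            f"with `--max-input-tokens` smaller than sliding_window={sliding_window} "
            f"(got max_input_tokens={max_input_tokens})."
        )


check_sliding_window({"sliding_window": 4096}, "mistral", 2048)  # OK: inputs always fit in the window
check_sliding_window({"sliding_window": 4096}, "mistral", 8192)  # raises ValueError
```

Second, the proportional block split from PATCH 2/2. `get_weights_col_packed_qkv` passes `[num_heads, num_key_value_heads, num_key_value_heads]`, so for a GQA model the columns of the packed QKV matrix are divided in proportion to the head counts. The helper below is copied from the patched `weights.py`; the Phi-3-medium-like head configuration is an assumption for the example:

```python
from typing import List, Union


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    # Copied from the patched server/text_generation_server/utils/weights.py.
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks


# Hypothetical Phi-3-medium-like GQA shape: 40 query heads, 10 KV heads,
# head_dim = 128, so the packed QKV matrix has (40 + 10 + 10) * 128 = 7680 columns.
num_heads, num_key_value_heads, head_dim = 40, 10, 128
total_size = (num_heads + 2 * num_key_value_heads) * head_dim

print(_blocks_to_block_sizes(total_size, [num_heads, num_key_value_heads, num_key_value_heads]))
# [5120, 1280, 1280]: Q gets 40/60 of the columns, K and V get 10/60 each.

print(_blocks_to_block_sizes(total_size, 3))
# [2560, 2560, 2560]: the plain int form still splits evenly, as for MHA or gate/up projections.
```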