Merge branch 'main' into amd-ci-fx

This commit is contained in:
fxmarty 2024-06-10 15:10:04 +02:00
commit d3c7f63416
13 changed files with 153 additions and 187 deletions

View File

@ -481,6 +481,7 @@ fn shard_manager(
rope_factor: Option<f32>, rope_factor: Option<f32>,
max_total_tokens: usize, max_total_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
log_level: LevelFilter, log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
@ -553,6 +554,10 @@ fn shard_manager(
shard_args.push(otlp_endpoint); shard_args.push(otlp_endpoint);
} }
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@ -1009,6 +1014,7 @@ fn spawn_shards(
args: &Args, args: &Args,
cuda_graphs: Vec<usize>, cuda_graphs: Vec<usize>,
max_total_tokens: usize, max_total_tokens: usize,
max_input_tokens: usize,
max_log_level: LevelFilter, max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
@ -1066,6 +1072,7 @@ fn spawn_shards(
rope_factor, rope_factor,
max_total_tokens, max_total_tokens,
max_batch_size, max_batch_size,
max_input_tokens,
otlp_endpoint, otlp_endpoint,
max_log_level, max_log_level,
status_sender, status_sender,
@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
&args, &args,
cuda_graphs, cuda_graphs,
max_total_tokens, max_total_tokens,
max_input_tokens,
max_log_level, max_log_level,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,

View File

@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
otlp_endpoint: Optional[str] = None, otlp_endpoint: Optional[str] = None,
max_input_tokens: Optional[int] = None,
): ):
if sharded: if sharded:
assert ( assert (
@ -98,6 +99,7 @@ def serve(
dtype, dtype,
trust_remote_code, trust_remote_code,
uds_path, uds_path,
max_input_tokens,
) )

View File

@ -169,10 +169,8 @@ if ENGINE == "ck":
): ):
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,
k, k,
@ -204,10 +202,7 @@ elif ENGINE == "triton":
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
): ):
if window_size_left != -1: # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
output, _ = triton_attention( output, _ = triton_attention(
q, q,
k, k,

View File

@ -14,10 +14,7 @@ def attention(
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
): ):
if window_size_left != -1: # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention( return ipex.llm.functional.varlen_attention(
q, q,
k, k,

View File

@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
return cls(linear) return cls(linear)
@classmethod @classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool): def load_qkv(
cls,
config,
prefix: str,
weights,
bias: bool,
num_heads: int,
num_key_value_heads: int,
):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) weight = weights.get_weights_col_packed_qkv(
prefix,
quantize=config.quantize,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
if bias: if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan") raise NotImplementedError("packed_qkv only implemented for baichuan")
else: else:

View File

@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi from text_generation_server.models.phi import Phi
from text_generation_server.utils.import_utils import SYSTEM
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later. # in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -257,6 +259,7 @@ def get_model(
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
max_input_tokens: int,
) -> Model: ) -> Model:
global FLASH_ATTENTION global FLASH_ATTENTION
if dtype is None: if dtype is None:
@ -410,11 +413,15 @@ def get_model(
"Sharding is currently not supported with `exl2` quantization" "Sharding is currently not supported with `exl2` quantization"
) )
sliding_window = config_dict.get("sliding_window", -1) sliding_window = config_dict.get("sliding_window", -1)
if sliding_window != -1 and not SUPPORTS_WINDOWING:
logger.warning( if (
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}" (sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING
and max_input_tokens > sliding_window
):
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
) )
FLASH_ATTENTION = False
if model_type == MAMBA: if model_type == MAMBA:
return Mamba( return Mamba(
@ -701,7 +708,6 @@ def get_model(
) )
if model_type == MISTRAL: if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMistral( return FlashMistral(
model_id, model_id,
@ -724,7 +730,6 @@ def get_model(
) )
if model_type == MIXTRAL: if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMixtral( return FlashMixtral(
model_id, model_id,
@ -747,7 +752,6 @@ def get_model(
) )
if model_type == STARCODER2: if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashStarcoder2( return FlashStarcoder2(
model_id, model_id,
@ -771,8 +775,7 @@ def get_model(
) )
if model_type == QWEN2: if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1) if FLASH_ATTENTION:
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
return FlashQwen2( return FlashQwen2(
model_id, model_id,
revision, revision,

View File

@ -20,7 +20,6 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu": if SYSTEM != "xpu":
@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.n_heads != config.attn_config.kv_n_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_qkv( return TensorParallelColumnLinear.load_qkv(
config, config,
prefix=f"{prefix}.Wqkv", prefix=f"{prefix}.Wqkv",
weights=weights, weights=weights,
bias=False, bias=False,
) num_heads=config.n_heads,
num_key_value_heads=config.attn_config.kv_n_heads,
def _load_gqa(config, prefix: str, weights):
assert config.d_model % config.n_heads == 0
assert config.n_heads % weights.process_group.size() == 0
head_dim = config.d_model // config.n_heads
world_size = weights.process_group.size()
rank = weights.process_group.rank()
q_block_size = config.d_model // world_size
q_start = rank * q_block_size
q_stop = (rank + 1) * q_block_size
kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
k_offset = config.d_model
k_start = k_offset + rank * kv_block_size
k_stop = k_offset + (rank + 1) * kv_block_size
v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
v_start = v_offset + rank * kv_block_size
v_stop = v_offset + (rank + 1) * kv_block_size
if config.quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
try:
qweight_slice = weights._get_slice(f"{prefix}.qweight")
q_qweight = qweight_slice[:, q_start:q_stop]
k_qweight = qweight_slice[:, k_start:k_stop]
v_qweight = qweight_slice[:, v_start:v_stop]
qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
)
qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
q_qzeros = qzeros_slice[:, q_start:q_stop]
k_qzeros = qzeros_slice[:, k_start:k_stop]
v_qzeros = qzeros_slice[:, v_start:v_stop]
qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
scales_slice = weights._get_slice(f"{prefix}.scales")
q_scales = scales_slice[:, q_start:q_stop]
k_scales = scales_slice[:, k_start:k_stop]
v_scales = scales_slice[:, v_start:v_stop]
scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
from text_generation_server.layers import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
)
if config.quantize == "gptq" and quant_method == "gptq":
g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
q_g_idx = g_idx_slice[:, q_start:q_stop]
k_g_idx = g_idx_slice[:, k_start:k_stop]
v_g_idx = g_idx_slice[:, v_start:v_stop]
w = [q_g_idx, k_g_idx, v_g_idx]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif config.quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conveersion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
elif config.quantize == "marlin":
# NOTE: at the time marlin support was added, the only model that
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
# but it requires manual concatenation of weight files.
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
k = qkv_slice[k_start:k_stop]
v = qkv_slice[v_start:v_stop]
weight = torch.cat([q, k, v], dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
) )

View File

@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
rank = weights.process_group.rank() rank = weights.process_group.rank()
# Weights # Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize) weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)
# Bias # Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias") slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

View File

@ -62,6 +62,8 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
weights=weights, weights=weights,
bias=bias, bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
) )
elif config.model_type == "baichuan": elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv( return TensorParallelColumnLinear.load_qkv(
@ -69,6 +71,8 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.W_pack", prefix=f"{prefix}.W_pack",
weights=weights, weights=weights,
bias=bias, bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
) )
# otherwise, load the default attention based on the number of heads # otherwise, load the default attention based on the number of heads
@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module):
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}" f"and `num_shards`: {weights.process_group.size()}"
) )
if config.num_key_value_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()

View File

@ -902,6 +902,8 @@ class FlashCausalLM(Model):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
): ):
torch.cuda.tunable.enable()
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True) torch.cuda.tunable.tuning_enable(True)
@ -910,8 +912,11 @@ class FlashCausalLM(Model):
int(val) int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
] ]
else: elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS tuning_sequences = CUDA_GRAPHS
else:
# For seqlen = 1, we dispatch to LLMM1 kernel.
tuning_sequences = [2, 3, 4, 5, 6, 7]
tunableop_filepath = os.path.join( tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE, HUGGINGFACE_HUB_CACHE,

View File

@ -199,6 +199,7 @@ def serve(
dtype: Optional[str], dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
uds_path: Path, uds_path: Path,
max_input_tokens: int,
): ):
async def serve_inner( async def serve_inner(
model_id: str, model_id: str,
@ -229,6 +230,7 @@ def serve(
speculate, speculate,
dtype, dtype,
trust_remote_code, trust_remote_code,
max_input_tokens,
) )
except Exception: except Exception:
logger.exception("Error when initializing model") logger.exception("Error when initializing model")

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass, field
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError from safetensors import safe_open, SafetensorError
import torch import torch
from loguru import logger from loguru import logger
@ -121,49 +120,62 @@ class Weights:
), f"The choosen size {size} is not compatible with sharding on {world_size} shards" ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim) return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str, blocks: int): def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
slice_ = self._get_slice(name) slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1] total_size = slice_.get_shape()[1]
assert ( block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
total_size % blocks == 0
), f"Prepacked quantized matrix is not divisible by {blocks}"
single_size = total_size // blocks
world_size = self.process_group.size() world_size = self.process_group.size()
rank = self.process_group.rank() rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
weights = [] weights = []
for block in range(blocks): block_offset = 0
weights.append( for block_size in block_sizes:
slice_[:, start + block * single_size : stop + block * single_size] assert (
) block_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
weights.append(slice_[:, block_offset + start : block_offset + stop])
block_offset += block_size
weight = torch.cat(weights, dim=1) weight = torch.cat(weights, dim=1)
weight = weight.to(device=self.device) weight = weight.to(device=self.device)
return weight return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str): def get_weights_col_packed_qkv(
return self.get_weights_col_packed(prefix, quantize, 3) self,
prefix: str,
quantize: str,
num_heads: int,
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2) return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int): def get_weights_col_packed(
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
""" """
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
""" """
if quantize in ["gptq", "awq"]: if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.gptq import GPTQWeight
try: try:
qweight = self._get_qweight(f"{prefix}.qweight", blocks) qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized." f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@ -171,8 +183,8 @@ class Weights:
bits, groupsize, _, quant_method = self._get_gptq_params() bits, groupsize, _, quant_method = self._get_gptq_params()
qzeros = self._get_qweight(f"{prefix}.qzeros", blocks) qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
scales = self._get_qweight(f"{prefix}.scales", blocks) scales = self._get_qweight(f"{prefix}.scales", block_sizes)
scales = scales.to(dtype=self.dtype) scales = scales.to(dtype=self.dtype)
if quantize == "gptq" and quant_method == "gptq": if quantize == "gptq" and quant_method == "gptq":
@ -205,27 +217,31 @@ class Weights:
elif quantize == "marlin": elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight from text_generation_server.layers.marlin import MarlinWeight
B = self._get_qweight(f"{prefix}.B", blocks) B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", blocks) s = self._get_qweight(f"{prefix}.s", block_sizes)
weight = MarlinWeight(B=B, s=s) weight = MarlinWeight(B=B, s=s)
else: else:
slice_ = self._get_slice(f"{prefix}.weight") slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0] total_size = slice_.get_shape()[0]
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" block_sizes = _blocks_to_block_sizes(
single_size = total_size // blocks total_size=total_size, blocks=block_sizes
)
world_size = self.process_group.size() world_size = self.process_group.size()
rank = self.process_group.rank() rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = [] tensors = []
for i in range(blocks): block_offset = 0
tensor = slice_[start + i * single_size : stop + i * single_size] for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked weights cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensor = slice_[block_offset + start : block_offset + stop]
tensors.append(tensor) tensors.append(tensor)
block_offset += block_size
weight = torch.cat(tensors, dim=0) weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device) weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype) weight = weight.to(dtype=self.dtype)
@ -593,3 +609,31 @@ class Weights:
self.quant_method = "awq" self.quant_method = "awq"
except Exception: except Exception:
pass pass
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the latter case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert (
total_size % total_blocks == 0
), f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks