Merge branch 'main' into amd-ci-fx
commit d3c7f63416
@@ -481,6 +481,7 @@ fn shard_manager(
    rope_factor: Option<f32>,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    max_input_tokens: usize,
    otlp_endpoint: Option<String>,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
        shard_args.push(otlp_endpoint);
    }

    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
    shard_args.push("--max-input-tokens".to_string());
    shard_args.push(max_input_tokens.to_string());

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -1009,6 +1014,7 @@ fn spawn_shards(
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: usize,
    max_input_tokens: usize,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                otlp_endpoint,
                max_log_level,
                status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,

@@ -42,6 +42,7 @@ def serve(
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    max_input_tokens: Optional[int] = None,
):
    if sharded:
        assert (
@@ -98,6 +99,7 @@ def serve(
        dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )

@@ -169,10 +169,8 @@ if ENGINE == "ck":
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        if window_size_left != -1:
            raise ValueError(
                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
        output, _ = triton_attention(
            q,
            k,

@@ -14,10 +14,7 @@ def attention(
    softmax_scale,
    window_size_left=-1,
):
    if window_size_left != -1:
        raise ValueError(
            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
        )
    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
    return ipex.llm.functional.varlen_attention(
        q,
        k,

@@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
        return cls(linear)

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            quantize=config.quantize,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:

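The extra `num_heads` / `num_key_value_heads` parameters let the loader split a fused QKV weight into proportionally sized column blocks (see the `Weights.get_weights_col_packed_qkv` change later in this diff). A standalone sketch with hypothetical head counts, only to show how the proportions map to column widths:

# Hypothetical head configuration; the block proportions [num_heads, kv, kv]
# come from the get_weights_col_packed_qkv change further down in this diff.
num_heads = 32
num_key_value_heads = 8
head_dim = 128

blocks = [num_heads, num_key_value_heads, num_key_value_heads]
q_cols, k_cols, v_cols = (b * head_dim for b in blocks)
assert (q_cols, k_cols, v_cols) == (4096, 1024, 1024)
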
@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION
    if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
            "Sharding is currently not supported with `exl2` quantization"
        )
    sliding_window = config_dict.get("sliding_window", -1)
    if sliding_window != -1 and not SUPPORTS_WINDOWING:
        logger.warning(
            f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"

    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )
        FLASH_ATTENTION = False

    if model_type == MAMBA:
        return Mamba(
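The intent of the new check above: on a backend without windowed attention, a sliding-window model is only accepted when every request fits inside the window. A minimal standalone sketch of the same condition, with made-up values:

# Standalone illustration of the check added above (all values are hypothetical).
SUPPORTS_WINDOWING = False      # e.g. a backend whose flash attention has no windowing
sliding_window = 4096           # from config_dict.get("sliding_window", -1)
max_input_tokens = 8192         # from the launcher's --max-input-tokens

if (
    (sliding_window is not None and sliding_window != -1)
    and not SUPPORTS_WINDOWING
    and max_input_tokens > sliding_window
):
    # Mirrors the ValueError raised in get_model: lowering --max-input-tokens
    # below the sliding window lets such a model run on this backend.
    raise ValueError("max_input_tokens must not exceed sliding_window on this backend")
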
@@ -701,7 +708,6 @@ def get_model(
        )

    if model_type == MISTRAL:
        sliding_window = config_dict.get("sliding_window", -1)
        if FLASH_ATTENTION:
            return FlashMistral(
                model_id,
@@ -724,7 +730,6 @@ def get_model(
        )

    if model_type == MIXTRAL:
        sliding_window = config_dict.get("sliding_window", -1)
        if FLASH_ATTENTION:
            return FlashMixtral(
                model_id,
@@ -747,7 +752,6 @@ def get_model(
        )

    if model_type == STARCODER2:
        sliding_window = config_dict.get("sliding_window", -1)
        if FLASH_ATTENTION:
            return FlashStarcoder2(
                model_id,
@@ -771,8 +775,7 @@ def get_model(
        )

    if model_type == QWEN2:
        sliding_window = config_dict.get("sliding_window", -1)
        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
        if FLASH_ATTENTION:
            return FlashQwen2(
                model_id,
                revision,

@@ -20,7 +20,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:


def load_attention(config, prefix, weights):
    if config.n_heads != config.attn_config.kv_n_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.Wqkv",
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.d_model % config.n_heads == 0
    assert config.n_heads % weights.process_group.size() == 0

    head_dim = config.d_model // config.n_heads
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    q_block_size = config.d_model // world_size
    q_start = rank * q_block_size
    q_stop = (rank + 1) * q_block_size

    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
    k_offset = config.d_model
    k_start = k_offset + rank * kv_block_size
    k_stop = k_offset + (rank + 1) * kv_block_size

    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
    v_start = v_offset + rank * kv_block_size
    v_stop = v_offset + (rank + 1) * kv_block_size

    if config.quantize in ["gptq", "awq"]:
        from text_generation_server.layers.gptq import GPTQWeight

        try:
            qweight_slice = weights._get_slice(f"{prefix}.qweight")
            q_qweight = qweight_slice[:, q_start:q_stop]
            k_qweight = qweight_slice[:, k_start:k_stop]
            v_qweight = qweight_slice[:, v_start:v_stop]

            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
            )

        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
        q_qzeros = qzeros_slice[:, q_start:q_stop]
        k_qzeros = qzeros_slice[:, k_start:k_stop]
        v_qzeros = qzeros_slice[:, v_start:v_stop]

        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)

        scales_slice = weights._get_slice(f"{prefix}.scales")
        q_scales = scales_slice[:, q_start:q_stop]
        k_scales = scales_slice[:, k_start:k_stop]
        v_scales = scales_slice[:, v_start:v_stop]

        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)

        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()

        from text_generation_server.layers import HAS_EXLLAMA

        use_exllama = (
            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
        )

        if config.quantize == "gptq" and quant_method == "gptq":
            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
            q_g_idx = g_idx_slice[:, q_start:q_stop]
            k_g_idx = g_idx_slice[:, k_start:k_stop]
            v_g_idx = g_idx_slice[:, v_start:v_stop]

            w = [q_g_idx, k_g_idx, v_g_idx]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif config.quantize == "gptq" and quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conveersion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            if use_exllama:
                g_idx = None
            else:
                g_idx = (
                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
                    // groupsize
                ).to(dtype=torch.int32)
        else:
            g_idx = None

        weight = GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=bits,
            groupsize=groupsize,
            use_exllama=use_exllama,
        )
    elif config.quantize == "marlin":
        # NOTE: at the time marlin support was added, the only model that
        # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
        # but it requires manual concatenation of weight files.
        raise RuntimeError("dbrx models with marlin quantization are not yet supported")
    else:
        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
        q = qkv_slice[q_start:q_stop]
        k = qkv_slice[k_start:k_stop]
        v = qkv_slice[v_start:v_stop]

        weight = torch.cat([q, k, v], dim=0)
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

    return TensorParallelColumnLinear(
        get_linear(weight, bias=None, quantize=config.quantize)
        num_heads=config.n_heads,
        num_key_value_heads=config.attn_config.kv_n_heads,
    )

@@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
    rank = weights.process_group.rank()

    # Weights
    weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
    weight = weights.get_weights_col_packed_qkv(
        f"{prefix}.c_attn",
        config.quantize,
        config.num_attention_heads,
        config.num_attention_heads,
    )

    # Bias
    slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

@@ -62,6 +62,8 @@ def load_attention(config, prefix, weights):
            prefix=f"{prefix}.qkv_proj",
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
    elif config.model_type == "baichuan":
        return TensorParallelColumnLinear.load_qkv(
@@ -69,6 +71,8 @@ def load_attention(config, prefix, weights):
            prefix=f"{prefix}.W_pack",
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )

    # otherwise, load the default attention based on the number of heads
@@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module):
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        if config.num_key_value_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()

@@ -902,6 +902,8 @@ class FlashCausalLM(Model):
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

@@ -910,8 +912,11 @@ class FlashCausalLM(Model):
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                else:
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # For seqlen = 1, we dispatch to LLMM1 kernel.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,

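The hunk above picks the TunableOp tuning sequence lengths with a simple priority: the PYTORCH_TUNABLEOP_SEQLENS environment variable, then the configured CUDA graph batch sizes, then a small default. A minimal standalone sketch of that selection, assuming CUDA_GRAPHS is a list of ints or None as in the surrounding code:

import os

# Hypothetical stand-in for the CUDA_GRAPHS global used above.
CUDA_GRAPHS = [1, 2, 4, 8]

if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
    # Explicit override, e.g. PYTORCH_TUNABLEOP_SEQLENS=1,2,4,8
    tuning_sequences = [
        int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
    ]
elif CUDA_GRAPHS is not None:
    tuning_sequences = CUDA_GRAPHS
else:
    # As noted above, seqlen = 1 is handled by the LLMM1 kernel instead.
    tuning_sequences = [2, 3, 4, 5, 6, 7]
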
@@ -199,6 +199,7 @@ def serve(
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
):
    async def serve_inner(
        model_id: str,
@@ -229,6 +230,7 @@ def serve(
                speculate,
                dtype,
                trust_remote_code,
                max_input_tokens,
            )
        except Exception:
            logger.exception("Error when initializing model")

@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
import os
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
@@ -121,49 +120,62 @@ class Weights:
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str, blocks: int):
    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert (
            total_size % blocks == 0
        ), f"Prepacked quantized matrix is not divisible by {blocks}"
        single_size = total_size // blocks
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        weights = []
        for block in range(blocks):
            weights.append(
                slice_[:, start + block * single_size : stop + block * single_size]
            )
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            weights.append(slice_[:, block_offset + start : block_offset + stop])
            block_offset += block_size

        weight = torch.cat(weights, dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 3)
    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
        already alternating Q,K,V within the main tensor.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
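The rewritten loop in `_get_qweight` above shards each packed block (Q, K, V) across ranks separately instead of assuming equally sized blocks. A small self-contained sketch of the same slicing on a dense tensor, with made-up sizes (8 query heads, 2 key-value heads, head_dim 4, 2 shards):

import torch

# Toy illustration of per-block sharding as in _get_qweight above (sizes are made up).
world_size, rank = 2, 0
head_dim = 4
block_sizes = [8 * head_dim, 2 * head_dim, 2 * head_dim]  # [Q, K, V] column widths
packed = torch.arange(sum(block_sizes)).unsqueeze(0)      # stand-in for the packed weight

shards = []
block_offset = 0
for block_size in block_sizes:
    shard_block_size = block_size // world_size
    start = rank * shard_block_size
    stop = (rank + 1) * shard_block_size
    shards.append(packed[:, block_offset + start : block_offset + stop])
    block_offset += block_size

local_weight = torch.cat(shards, dim=1)  # this rank's slice of Q, K and V columns
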
@@ -171,8 +183,8 @@ class Weights:

            bits, groupsize, _, quant_method = self._get_gptq_params()

            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
            scales = self._get_qweight(f"{prefix}.scales", blocks)
            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and quant_method == "gptq":
@@ -205,27 +217,31 @@ class Weights:
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            B = self._get_qweight(f"{prefix}.B", blocks)
            s = self._get_qweight(f"{prefix}.s", blocks)
            B = self._get_qweight(f"{prefix}.B", block_sizes)
            s = self._get_qweight(f"{prefix}.s", block_sizes)
            weight = MarlinWeight(B=B, s=s)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
            single_size = total_size // blocks
            block_sizes = _blocks_to_block_sizes(
                total_size=total_size, blocks=block_sizes
            )

            world_size = self.process_group.size()
            rank = self.process_group.rank()

            assert (
                single_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            tensors = []
            for i in range(blocks):
                tensor = slice_[start + i * single_size : stop + i * single_size]
            block_offset = 0
            for block_size in block_sizes:
                assert (
                    block_size % world_size == 0
                ), f"Prepacked weights cannot be sharded across {world_size} shards"
                shard_block_size = block_size // world_size
                start = rank * shard_block_size
                stop = (rank + 1) * shard_block_size
                tensor = slice_[block_offset + start : block_offset + stop]
                tensors.append(tensor)
                block_offset += block_size
            weight = torch.cat(tensors, dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
@@ -593,3 +609,31 @@ class Weights:
            self.quant_method = "awq"
        except Exception:
            pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks
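A quick usage check of the helper above, matching the example in its own docstring (assuming `_blocks_to_block_sizes` is imported from the weights module where it is defined):

# Matches the docstring example: proportions [2, 1, 1] over 1024 columns.
assert _blocks_to_block_sizes(total_size=1024, blocks=[2, 1, 1]) == [512, 256, 256]
# An int keeps the previous behaviour of equally sized blocks.
assert _blocks_to_block_sizes(total_size=1024, blocks=4) == [256, 256, 256, 256]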