Merge branch 'main' into amd-ci-fx
commit d3c7f63416
@@ -481,6 +481,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }

+    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -1009,6 +1014,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
                 rope_factor,
                 max_total_tokens,
                 max_batch_size,
+                max_input_tokens,
                 otlp_endpoint,
                 max_log_level,
                 status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
@@ -42,6 +42,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,6 +99,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
+        max_input_tokens,
     )


@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+
+        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,
@@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)

     @classmethod
-    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+    def load_qkv(
+        cls,
+        config,
+        prefix: str,
+        weights,
+        bias: bool,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
         """Specific method when the QKV was joined after the fact"""
-        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
+        weight = weights.get_weights_col_packed_qkv(
+            prefix,
+            quantize=config.quantize,
+            num_heads=num_heads,
+            num_key_value_heads=num_key_value_heads,
+        )
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
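
The extra `num_heads` / `num_key_value_heads` arguments let a fused QKV tensor be split in proportion to the head counts rather than into three equal parts, which matters for grouped-query attention where K and V carry fewer heads than Q. A minimal Python sketch of that proportional split, with hypothetical names and sizes standing in for what `get_weights_col_packed_qkv` receives as `[num_heads, num_key_value_heads, num_key_value_heads]`:

# Sketch only: illustrates the Q/K/V column proportions implied by the new
# signature; the function name and sizes here are illustrative, not from the diff.
def qkv_block_sizes(total_cols: int, num_heads: int, num_key_value_heads: int):
    blocks = [num_heads, num_key_value_heads, num_key_value_heads]
    assert total_cols % sum(blocks) == 0
    unit = total_cols // sum(blocks)
    return [unit * b for b in blocks]

# e.g. a GQA model with 32 query heads, 8 KV heads and head_dim 128:
# packed width = (32 + 8 + 8) * 128 = 6144 -> [4096, 1024, 1024]
print(qkv_block_sizes(6144, num_heads=32, num_key_value_heads=8))
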
@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.phi import Phi

+from text_generation_server.utils.import_utils import SYSTEM
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
             "Sharding is currently not supported with `exl2` quantization"
         )
     sliding_window = config_dict.get("sliding_window", -1)
-    if sliding_window != -1 and not SUPPORTS_WINDOWING:
-        logger.warning(
-            f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
+    if (
+        (sliding_window is not None and sliding_window != -1)
+        and not SUPPORTS_WINDOWING
+        and max_input_tokens > sliding_window
+    ):
+        raise ValueError(
+            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
-        FLASH_ATTENTION = False

     if model_type == MAMBA:
         return Mamba(
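
The new guard only rejects the model when the backend cannot do windowed attention and the requested `--max-input-tokens` actually exceeds the model's sliding window; shorter inputs remain usable. A standalone sketch of the same condition with made-up values (the argument names mirror the variables in the hunk above, the helper itself is hypothetical):

# Sketch of the load-time check added above, exercised with hypothetical inputs.
def check_sliding_window(sliding_window, supports_windowing, max_input_tokens):
    if (
        (sliding_window is not None and sliding_window != -1)
        and not supports_windowing
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"Backend lacks windowing: pass --max-input-tokens <= {sliding_window} "
            f"(got {max_input_tokens})"
        )

check_sliding_window(4096, supports_windowing=False, max_input_tokens=2048)  # ok
# check_sliding_window(4096, supports_windowing=False, max_input_tokens=8192)  # raises
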
@@ -701,7 +708,6 @@ def get_model(
         )

     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
@@ -724,7 +730,6 @@ def get_model(
         )

     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
@@ -747,7 +752,6 @@ def get_model(
         )

     if model_type == STARCODER2:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
@@ -771,8 +775,7 @@ def get_model(
         )

     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
@@ -20,7 +20,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM

 if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:


 def load_attention(config, prefix, weights):
-    if config.n_heads != config.attn_config.kv_n_heads:
-        return _load_gqa(config, prefix, weights)
-    else:
-        return TensorParallelColumnLinear.load_qkv(
-            config,
-            prefix=f"{prefix}.Wqkv",
-            weights=weights,
-            bias=False,
-        )
-
-
-def _load_gqa(config, prefix: str, weights):
-    assert config.d_model % config.n_heads == 0
-    assert config.n_heads % weights.process_group.size() == 0
-
-    head_dim = config.d_model // config.n_heads
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    q_block_size = config.d_model // world_size
-    q_start = rank * q_block_size
-    q_stop = (rank + 1) * q_block_size
-
-    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
-    k_offset = config.d_model
-    k_start = k_offset + rank * kv_block_size
-    k_stop = k_offset + (rank + 1) * kv_block_size
-
-    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
-    v_start = v_offset + rank * kv_block_size
-    v_stop = v_offset + (rank + 1) * kv_block_size
-
-    if config.quantize in ["gptq", "awq"]:
-        from text_generation_server.layers.gptq import GPTQWeight
-
-        try:
-            qweight_slice = weights._get_slice(f"{prefix}.qweight")
-            q_qweight = qweight_slice[:, q_start:q_stop]
-            k_qweight = qweight_slice[:, k_start:k_stop]
-            v_qweight = qweight_slice[:, v_start:v_stop]
-
-            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
-        except RuntimeError:
-            raise RuntimeError(
-                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
-            )
-
-        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
-        q_qzeros = qzeros_slice[:, q_start:q_stop]
-        k_qzeros = qzeros_slice[:, k_start:k_stop]
-        v_qzeros = qzeros_slice[:, v_start:v_stop]
-
-        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
-
-        scales_slice = weights._get_slice(f"{prefix}.scales")
-        q_scales = scales_slice[:, q_start:q_stop]
-        k_scales = scales_slice[:, k_start:k_stop]
-        v_scales = scales_slice[:, v_start:v_stop]
-
-        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
-
-        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
-
-        from text_generation_server.layers import HAS_EXLLAMA
-
-        use_exllama = (
-            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
-        )
-
-        if config.quantize == "gptq" and quant_method == "gptq":
-            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
-            q_g_idx = g_idx_slice[:, q_start:q_stop]
-            k_g_idx = g_idx_slice[:, k_start:k_stop]
-            v_g_idx = g_idx_slice[:, v_start:v_stop]
-
-            w = [q_g_idx, k_g_idx, v_g_idx]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
-        elif config.quantize == "gptq" and quant_method == "awq":
-            log_once(
-                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
-            )
-            from text_generation_server.layers.awq.conveersion_utils import (
-                fast_awq_to_gptq,
-            )
-
-            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-            if use_exllama:
-                g_idx = None
-            else:
-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
-        else:
-            g_idx = None
-
-        weight = GPTQWeight(
-            qweight=qweight,
-            qzeros=qzeros,
-            scales=scales,
-            g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
-            use_exllama=use_exllama,
-        )
-    elif config.quantize == "marlin":
-        # NOTE: at the time marlin support was added, the only model that
-        # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
-        # but it requires manual concatenation of weight files.
-        raise RuntimeError("dbrx models with marlin quantization are not yet supported")
-    else:
-        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
-        q = qkv_slice[q_start:q_stop]
-        k = qkv_slice[k_start:k_stop]
-        v = qkv_slice[v_start:v_stop]
-
-        weight = torch.cat([q, k, v], dim=0)
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear.load_qkv(
+        config,
+        prefix=f"{prefix}.Wqkv",
+        weights=weights,
+        bias=False,
+        num_heads=config.n_heads,
+        num_key_value_heads=config.attn_config.kv_n_heads,
+    )

@@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
     rank = weights.process_group.rank()

     # Weights
-    weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
+    weight = weights.get_weights_col_packed_qkv(
+        f"{prefix}.c_attn",
+        config.quantize,
+        config.num_attention_heads,
+        config.num_attention_heads,
+    )

     # Bias
     slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
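
GPT-2 uses plain multi-head attention, so the query and key/value head counts are the same and `config.num_attention_heads` is passed for both arguments; the packed `c_attn` weight therefore still splits into three equal blocks. A small worked example under that assumption (GPT-2 small sizes, used here only for illustration):

# GPT-2 small: hidden size 768, 12 heads, fused c_attn width 3 * 768 = 2304.
# With equal head counts the proportional split [12, 12, 12] degenerates to
# three equal blocks, matching the old hard-coded blocks=3 behaviour.
total_cols = 2304
blocks = [12, 12, 12]
unit = total_cols // sum(blocks)   # 64 columns per head
print([unit * b for b in blocks])  # [768, 768, 768]
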
@@ -62,6 +62,8 @@ def load_attention(config, prefix, weights):
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
+            num_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
         return TensorParallelColumnLinear.load_qkv(
@@ -69,6 +71,8 @@ def load_attention(config, prefix, weights):
             prefix=f"{prefix}.W_pack",
             weights=weights,
             bias=bias,
+            num_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
         )

     # otherwise, load the default attention based on the number of heads
@@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        if config.num_key_value_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
@@ -902,6 +902,8 @@ class FlashCausalLM(Model):
             os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
             or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
         ):
+            torch.cuda.tunable.enable()
+
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                 torch.cuda.tunable.tuning_enable(True)

@@ -910,8 +912,11 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            else:
+            elif CUDA_GRAPHS is not None:
                 tuning_sequences = CUDA_GRAPHS
+            else:
+                # For seqlen = 1, we dispatch to LLMM1 kernel.
+                tuning_sequences = [2, 3, 4, 5, 6, 7]

             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
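
After this change the TunableOp tuning sequence lengths are resolved in priority order: an explicit `PYTORCH_TUNABLEOP_SEQLENS` list wins, then the configured CUDA graph batch sizes, and only then the built-in default (sequence length 1 is skipped because it dispatches to the LLMM1 kernel). A compact sketch of that selection, with the environment lookup and `cuda_graphs` treated as plain inputs rather than the real module globals:

import os
from typing import List, Optional

def resolve_tuning_sequences(cuda_graphs: Optional[List[int]]) -> List[int]:
    # Mirrors the priority order in the hunk above; a sketch, not the actual method.
    env = os.environ.get("PYTORCH_TUNABLEOP_SEQLENS")
    if env is not None:
        return [int(val) for val in env.split(",")]
    elif cuda_graphs is not None:
        return cuda_graphs
    else:
        # For seqlen = 1, we dispatch to LLMM1 kernel.
        return [2, 3, 4, 5, 6, 7]

print(resolve_tuning_sequences(None))        # [2, 3, 4, 5, 6, 7]
print(resolve_tuning_sequences([1, 2, 4]))   # [1, 2, 4]
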
@@ -199,6 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -229,6 +230,7 @@ def serve(
                 speculate,
                 dtype,
                 trust_remote_code,
+                max_input_tokens,
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -1,7 +1,6 @@
-from dataclasses import dataclass, field
 import os
 from pathlib import Path
-from typing import List, Dict, Optional, Set, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
@@ -121,49 +120,62 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)

-    def _get_qweight(self, name: str, blocks: int):
+    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert (
-            total_size % blocks == 0
-        ), f"Prepacked quantized matrix is not divisible by {blocks}"
-        single_size = total_size // blocks
+        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
         world_size = self.process_group.size()
         rank = self.process_group.rank()

-        assert (
-            single_size % world_size == 0
-        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
-        block_size = single_size // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
-
         weights = []
-        for block in range(blocks):
-            weights.append(
-                slice_[:, start + block * single_size : stop + block * single_size]
-            )
+        block_offset = 0
+        for block_size in block_sizes:
+            assert (
+                block_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            shard_block_size = block_size // world_size
+            start = rank * shard_block_size
+            stop = (rank + 1) * shard_block_size
+            weights.append(slice_[:, block_offset + start : block_offset + stop])
+            block_offset += block_size

         weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight

-    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
-        return self.get_weights_col_packed(prefix, quantize, 3)
+    def get_weights_col_packed_qkv(
+        self,
+        prefix: str,
+        quantize: str,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
+        return self.get_weights_col_packed(
+            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
+        )

     def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
         return self.get_weights_col_packed(prefix, quantize, 2)

-    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
+    def get_weights_col_packed(
+        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
+    ):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
-        already alternating Q,K,V within the main tensor
+        already alternating Q,K,V within the main tensor.
+
+        The columns are split in equally sized blocks when blocks is an `int`, or
+        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
+        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
+        convenient for e.g. splitting QKV without knowing the storage details of
+        quantized weights.
         """
         if quantize in ["gptq", "awq"]:
             from text_generation_server.layers.gptq import GPTQWeight

             try:
-                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
+                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
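
The rewritten loop walks the packed blocks left to right, keeping a running `block_offset`, and each rank takes its shard from inside every block; with unequal Q/K/V blocks this is what keeps each shard's Q, K and V columns aligned. A tiny simulation of the column ranges each rank would read, using hypothetical block sizes and a helper name that is not part of the diff:

# Simulates the column ranges selected by the sharding loop above.
# block_sizes and world_size are example values, not taken from a real model.
def shard_column_ranges(block_sizes, world_size, rank):
    ranges = []
    block_offset = 0
    for block_size in block_sizes:
        assert block_size % world_size == 0
        shard_block_size = block_size // world_size
        start = rank * shard_block_size
        stop = (rank + 1) * shard_block_size
        ranges.append((block_offset + start, block_offset + stop))
        block_offset += block_size
    return ranges

# Q/K/V packed as [512, 256, 256] columns, sharded over 2 ranks:
print(shard_column_ranges([512, 256, 256], world_size=2, rank=0))
# [(0, 256), (512, 640), (768, 896)]
print(shard_column_ranges([512, 256, 256], world_size=2, rank=1))
# [(256, 512), (640, 768), (896, 1024)]
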
@@ -171,8 +183,8 @@ class Weights:

             bits, groupsize, _, quant_method = self._get_gptq_params()

-            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
-            scales = self._get_qweight(f"{prefix}.scales", blocks)
+            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
+            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
             scales = scales.to(dtype=self.dtype)

             if quantize == "gptq" and quant_method == "gptq":
@@ -205,27 +217,31 @@ class Weights:
         elif quantize == "marlin":
             from text_generation_server.layers.marlin import MarlinWeight

-            B = self._get_qweight(f"{prefix}.B", blocks)
-            s = self._get_qweight(f"{prefix}.s", blocks)
+            B = self._get_qweight(f"{prefix}.B", block_sizes)
+            s = self._get_qweight(f"{prefix}.s", block_sizes)
             weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
-            single_size = total_size // blocks
+            block_sizes = _blocks_to_block_sizes(
+                total_size=total_size, blocks=block_sizes
+            )
+
             world_size = self.process_group.size()
             rank = self.process_group.rank()

-            assert (
-                single_size % world_size == 0
-            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
-            block_size = single_size // world_size
-            start = rank * block_size
-            stop = (rank + 1) * block_size
             tensors = []
-            for i in range(blocks):
-                tensor = slice_[start + i * single_size : stop + i * single_size]
+            block_offset = 0
+            for block_size in block_sizes:
+                assert (
+                    block_size % world_size == 0
+                ), f"Prepacked weights cannot be sharded across {world_size} shards"
+                shard_block_size = block_size // world_size
+                start = rank * shard_block_size
+                stop = (rank + 1) * shard_block_size
+                tensor = slice_[block_offset + start : block_offset + stop]
                 tensors.append(tensor)
+                block_offset += block_size
             weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
@@ -593,3 +609,31 @@ class Weights:
                 self.quant_method = "awq"
             except Exception:
                 pass
+
+
+def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
+    """
+    Convert block count or proportions to block sizes.
+
+    This function accepts
+
+    - The number of blocks (int), in which case the block size is
+      total_size//blocks; or
+    - A list of block sizes (List[int]).
+
+    In the latter case, if sum(blocks) < total_size, the ratios between
+    the block sizes will be preserved. For instance, if blocks is
+    [2, 1, 1] and total_size is 1024, the returned block sizes are
+    [512, 256, 256].
+    """
+    if isinstance(blocks, list):
+        total_blocks = sum(blocks)
+        assert (
+            total_size % total_blocks == 0
+        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
+        part_size = total_size // total_blocks
+        return [part_size * block for block in blocks]
+    else:
+        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+        single_size = total_size // blocks
+        return [single_size] * blocks
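
A quick usage check of the helper added above, assuming `_blocks_to_block_sizes` is in scope; the numbers are illustrative only:

# Both call styles supported by the new helper:
print(_blocks_to_block_sizes(total_size=1024, blocks=[2, 1, 1]))  # [512, 256, 256]
print(_blocks_to_block_sizes(total_size=2304, blocks=3))          # [768, 768, 768]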