feat: add cuda memory fraction (#659)

Close #673
OlivierDehaene 2023-07-24 11:43:58 +02:00 committed by GitHub
parent 1da642bd0e
commit 73a4d65d26
14 changed files with 127 additions and 60 deletions

View File

@@ -48,8 +48,12 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4)
+async def test_flash_llama_gptq_load(
+    flash_llama_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_gptq, "Test request", max_new_tokens=10, n=4
+    )
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])

View File

@@ -17,7 +17,9 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
     response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=20,
+        decoder_input_details=True,
     )
     assert response.details.generated_tokens == 20
     assert response == response_snapshot
@@ -25,7 +27,9 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
+async def test_flash_starcoder_gptq_default_params(
+    flash_starcoder_gptq, response_snapshot
+):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
         max_new_tokens=20,
@@ -40,10 +44,17 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
+async def test_flash_starcoder_gptq_load(
+    flash_starcoder_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_starcoder_gptq,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=10,
+        n=4,
+    )
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
     assert responses == response_snapshot

View File

@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
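
Note: the doc comment on `cuda_memory_fraction` above states that the allowed memory equals the total visible memory multiplied by the fraction; the launcher itself only forwards the value to each shard through the `CUDA_MEMORY_FRACTION` env var. A quick worked example of the semantics, with made-up numbers:

total_visible_memory_gb = 80        # e.g. a hypothetical single 80 GB GPU
cuda_memory_fraction = 0.8          # --cuda-memory-fraction 0.8

allowed_memory_gb = total_visible_memory_gb * cuda_memory_fraction
print(allowed_memory_gb)            # 64.0 -> each shard may use at most ~64 GB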

View File

@@ -101,8 +101,12 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        // Take the minimum value
+        let results = join_all(futures)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<Option<u32>>>>()?;
+        Ok(results.into_iter().flatten().min())
     }
 
     /// Generate one token for each request in the given batch
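
Note: the warmup aggregation above now keeps the smallest value reported by any shard instead of assuming all shards return the same message. A minimal Python sketch of the same idea, using illustrative names rather than the actual router types:

from typing import Optional, Sequence

def aggregate_warmup(results: Sequence[Optional[int]]) -> Optional[int]:
    """Keep the smallest non-None value reported by the shards."""
    values = [r for r in results if r is not None]
    return min(values) if values else None

# e.g. three shards report different supported token budgets
assert aggregate_warmup([512, None, 384]) == 384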

View File

@@ -11,7 +11,7 @@ setup(
                 "exllama_kernels/cuda_buffers.cu",
                 "exllama_kernels/cuda_func/column_remap.cu",
                 "exllama_kernels/cuda_func/q4_matmul.cu",
-                "exllama_kernels/cuda_func/q4_matrix.cu"
+                "exllama_kernels/cuda_func/q4_matrix.cu",
             ],
         )
     ],

View File

@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
 )
 from safetensors import SafetensorError
 
+
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
@@ -78,6 +79,7 @@ def _load_multi_mqa_gptq(
         bits, groupsize = weights._get_gptq_qparams()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
+
         use_exllama = HAS_EXLLAMA
 
         weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)

View File

@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        free_memory = max(
+            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+        )
 
         num_blocks = (
             int(free_memory // total_cache_size)
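
Note: to make the new free-memory budget concrete, a small worked example with made-up numbers (at runtime the values come from torch.cuda.mem_get_info and get_device_properties):

GiB = 1024**3
total_gpu_memory = 80 * GiB      # hypothetical 80 GB card
total_free_memory = 70 * GiB     # 10 GB already in use
MEMORY_FRACTION = 0.9            # CUDA_MEMORY_FRACTION=0.9

# Reserve (1 - fraction) of the whole card, then budget whatever free memory remains
free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
print(free_memory / GiB)         # 62.0 GiB left for KV-cache blocks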

View File

@@ -10,6 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
 
+
 class Model(ABC):
     def __init__(
         self,
@@ -21,9 +22,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)

View File

@@ -16,7 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
-
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -146,7 +145,10 @@ def serve(
             # When using GPTQ, Exllama kernels need some global kernels
             # For which we have the finale shapes only after the model has loaded
             # This will allocate those buffers.
-            from text_generation_server.utils.gptq.exllama import create_exllama_buffers
+            from text_generation_server.utils.gptq.exllama import (
+                create_exllama_buffers,
+            )
+
             create_exllama_buffers()
         except ImportError:
             pass

View File

@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-        return torch.distributed.group.WORLD, rank, world_size
+        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
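
Note: for context, a standalone sketch of what the new MEMORY_FRACTION plumbing does on the Python side (assuming a CUDA device is present; the device index is illustrative):

import os

import torch

# Same convention as the launcher: the env var carries the fraction, default 1.0
memory_fraction = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

if torch.cuda.is_available():
    device = 0
    torch.cuda.set_device(device)
    # Cap this process' CUDA caching allocator at fraction * total device memory;
    # allocations beyond the cap raise an out-of-memory error instead of growing further.
    torch.cuda.set_per_process_memory_fraction(memory_fraction, device)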

View File

@@ -1,28 +1,28 @@
 import torch
 from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device = "meta")
+none_tensor = torch.empty((1, 1), device="meta")
+
 
 def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
-    return make_q4(qweight,
-                   qzeros,
-                   scales,
-                   g_idx if g_idx is not None else none_tensor,
-                   device)
+    return make_q4(
+        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
+    )
+
 
 def ext_q4_matmul(x, q4, q4_width):
     """Matrix multiplication, returns x @ q4"""
     outshape = x.shape[:-1] + (q4_width,)
     x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
     q4_matmul(x, q4, output)
 
     return output.view(outshape)
 
 MAX_DQ = 1
 MAX_INNER = 1
 ACT_ORDER = False
@@ -31,9 +31,10 @@ DEVICE = None
 TEMP_STATE = None
 TEMP_DQ = None
 
+
 def create_exllama_buffers():
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
 
     if ACT_ORDER:
         # TODO: this should be set to rust side `max_total_tokens`, but TGI
         # does not offer an API to expose this variable to python, as this variable
@@ -45,7 +46,9 @@ def create_exllama_buffers():
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
-    temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
+    temp_state = torch.zeros(
+        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
+    )
     temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
 
     # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
@@ -56,10 +59,12 @@ def create_exllama_buffers():
     matmul_no_half2 = False
     set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
 
     TEMP_STATE, TEMP_DQ = temp_state, temp_dq
 
+
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
+
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
         assert bits == 4
@@ -70,20 +75,24 @@ class Ex4bitLinear:
         self.scales = scales
         self.g_idx = g_idx.cpu() if g_idx is not None else None
         self.bias = bias if bias is not None else None
 
-        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
+        if self.g_idx is not None and (
+            (self.g_idx == 0).all()
+            or torch.equal(
+                g_idx.cpu(),
+                torch.tensor(
+                    [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
+                ),
+            )
+        ):
             self.empty_g_idx = True
             self.g_idx = None
 
         assert self.device.type == "cuda"
         assert self.device.index is not None
 
         self.q4 = ext_make_q4(
-            self.qweight,
-            self.qzeros,
-            self.scales,
-            self.g_idx,
-            self.device.index
+            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
         )
 
         self.height = qweight.shape[0] * 8
@@ -99,7 +108,8 @@ class Ex4bitLinear:
         # Handle act-order matrix
         if self.g_idx is not None:
-            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+            if self.groupsize is None:
+                raise ValueError("Found group index but no groupsize. What do?")
             self.act_order = True
         else:
             self.act_order = False
@@ -112,7 +122,7 @@ class Ex4bitLinear:
             MAX_INNER = max(MAX_INNER, self.height, self.width)
             ACT_ORDER = True
 
     def forward(self, x):
         out = ext_q4_matmul(x, self.q4, self.width)

View File

@@ -815,11 +815,7 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
             tensor = current_tensor.to(device=torch.device("cuda:0"))
             if current_tensor.requires_grad:
                 tensor = nn.Parameter(tensor)
-            setdeepattr(
-                module,
-                local_param,
-                tensor
-            )
+            setdeepattr(module, local_param, tensor)
 
     return inner

View File

@@ -17,9 +17,10 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
 HAS_EXLLAMA = True
 if os.getenv("DISABLE_EXLLAMA") == "True":
-    HAS_EXLLAMA=False
+    HAS_EXLLAMA = False
 try:
     from text_generation_server.utils.gptq.exllama import Ex4bitLinear
 except ImportError:

View File

@@ -146,7 +146,16 @@ class Weights:
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                    if (
+                        not torch.equal(
+                            g_idx.cpu(),
+                            torch.tensor(
+                                [i // groupsize for i in range(g_idx.shape[0])],
+                                dtype=torch.int32,
+                            ),
+                        )
+                        and not (g_idx == 0).all()
+                    ):
                         # Exllama implementation does not support row tensor parallelism with act-order, as
                         # it would require to reorder input activations that are split unto several GPUs
                         use_exllama = False
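
Note: the condition above asks whether g_idx encodes anything beyond the default group layout. A hedged, standalone Python sketch of that check with toy values (not the actual Weights class):

import torch

def g_idx_is_trivial(g_idx: torch.Tensor, groupsize: int) -> bool:
    """True when the group index is all zeros or the default i // groupsize layout."""
    default = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), default)

# A weight with 8 rows and groupsize 4 in the default, act-order-free layout
assert g_idx_is_trivial(torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32), 4)
# A shuffled index means act-order was used, so exllama row tensor parallelism is skipped
assert not g_idx_is_trivial(torch.tensor([1, 0, 1, 0, 1, 0, 1, 0], dtype=torch.int32), 4)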
@@ -154,18 +163,21 @@ class Weights:
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
 
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             if use_exllama:
                 if not HAS_EXLLAMA:
-                    logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
+                    logger.warning(
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
+                    )
                     use_exllama = False
                 else:
                     logger.info("Using exllama kernels")
 
             if use_exllama:
                 if groupsize >= 0:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
@@ -173,7 +185,9 @@ class Weights:
                     qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                     scales = self.get_sharded(f"{prefix}.scales", dim=0)
                 else:
-                    raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
+                    raise RuntimeError(
+                        "Using exllama GPTQ kernel with groupsize<1 is not supported"
+                    )
                     # qzeros = self.get_tensor(f"{prefix}.qzeros")
                     # scales = self.get_tensor(f"{prefix}.scales")