feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add docs on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently every place we use `get_{tensor|sharded}` needs to check for quantization. My idea is to reintegrate as much as possible into `utils/layer.py` by expanding `load_multi` to be a bit more generic. This might require some thinking, but ultimately the `qweight, qzeros, scales, g_idx` should be loaded in a single place, independent of bias presence.

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
This commit is contained in:
parent
bd3a9d8e85
commit
aefde28b45
@@ -159,6 +159,11 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
 # Install launcher
 COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    g++ \
+    && rm -rf /var/lib/apt/lists/*
+
 # AWS Sagemaker compatbile image
 FROM base as sagemaker
File diff suppressed because it is too large
@@ -1,21 +1,21 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
-bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
 certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
 charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
 click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
-deprecated==1.2.13 ; python_version >= "3.9" and python_version < "4.0"
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
-filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0"
+filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
-fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "4.0"
+fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
-googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "4.0"
+googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
 grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
 huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
-idna==3.4 ; python_version >= "3.9" and python_version < "4"
+idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
+numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -26,17 +26,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
 packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
-protobuf==4.23.1 ; python_version >= "3.9" and python_version < "4.0"
+protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
 pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
+regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
 safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
 setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
+typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
-urllib3==2.0.2 ; python_version >= "3.9" and python_version < "4.0"
+urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -151,5 +151,37 @@ def download_weights(
         utils.convert_files(local_pt_files, local_st_files)
 
 
+@app.command()
+def quantize(
+    model_id: str,
+    output_dir: str,
+    revision: Optional[str] = None,
+    logger_level: str = "INFO",
+    json_output: bool = False,
+    trust_remote_code: bool = False,
+    upload_to_model_id: Optional[str] = None,
+    percdamp: float = 0.01,
+    act_order: bool = False,
+):
+    download_weights(
+        model_id=model_id,
+        revision=revision,
+        logger_level=logger_level,
+        json_output=json_output,
+    )
+    from text_generation_server.utils.gptq.quantize import quantize
+
+    quantize(
+        model_id=model_id,
+        bits=4,
+        groupsize=128,
+        output_dir=output_dir,
+        trust_remote_code=trust_remote_code,
+        upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
+        act_order=act_order,
+    )
+
+
 if __name__ == "__main__":
     app()
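Aside (not part of the diff): for local testing, the new command can be driven from the shell as the error message further below suggests (`text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`), or called as a plain Python function. The sketch below assumes the command lives in `text_generation_server.cli`; the model ID and output path are placeholders, and a full GPTQ calibration run needs a CUDA GPU and a fair amount of time.

# Hedged usage sketch, not part of the PR; adjust the placeholders to your setup.
from text_generation_server.cli import quantize

quantize(
    model_id="huggyllama/llama-7b",    # placeholder source model
    output_dir="/data/llama-7b-gptq",  # placeholder destination
    trust_remote_code=False,
    percdamp=0.01,
    act_order=False,
)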
@@ -246,6 +246,10 @@ def get_model(
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
+    if quantize == "gptq":
+        raise ValueError(
+            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+        )
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
@@ -42,7 +42,8 @@ from text_generation_server.utils.layers import (
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
@@ -57,19 +58,21 @@ def load_row(config, prefix: str, weights, bias: bool):
 
 
 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
-
-    weight = (
-        weight.view(
-            num_heads,
-            3,
-            head_size,
-            hidden_size,
-        )
-        .permute(1, 0, 2, 3)
-        .reshape(-1, hidden_size)
-    )
+    weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
+    if isinstance(weight, torch.Tensor):
+        # Only on non quantized versions
+        weight = (
+            weight.view(
+                num_heads,
+                3,
+                head_size,
+                hidden_size,
+            )
+            .permute(1, 0, 2, 3)
+            .reshape(-1, hidden_size)
+        )
+
+    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
 
     linear = get_linear(weight, bias, config.quantize)
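Aside (not part of the diff): the `isinstance(weight, torch.Tensor)` branch above is the calling convention this PR introduces everywhere. For fp16 or bitsandbytes checkpoints the new `get_multi_weights_*` helpers hand back a plain sharded tensor, while for `quantize == "gptq"` they hand back the packed GPTQ components that `get_linear` turns into a `QuantLinear`. A hedged sketch of the caller-side contract, with names mirrored from this diff rather than the actual utils code:

# Hedged sketch of the two shapes a loader has to handle.
import torch

def describe_weight(weight):
    if isinstance(weight, torch.Tensor):
        # unquantized path: a single sharded tensor
        return f"dense tensor {tuple(weight.shape)}"
    # gptq path: packed components, per this PR's convention
    qweight, qzeros, scales, g_idx, bits, groupsize = weight
    return f"gptq pack: qweight {tuple(qweight.shape)}, {bits}-bit, groupsize {groupsize}"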
@@ -21,7 +21,8 @@ from text_generation_server.utils.layers import (
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            q_tensor = slice_[start:stop]
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
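Aside (not part of the diff): the `qzeros` slicing above looks different from the `qweight`/`scales` slicing because GPTQ packs 32 // bits zero-points into each int32 column, so with 4-bit weights the key/value block occupies `2 * head_size * 4 // 32` packed columns instead of `2 * head_size` raw ones. A small hedged sanity check of that arithmetic (the head size is illustrative, not taken from a real config):

# Packing arithmetic behind the `* 4 // 32` factors above.
bits = 4
packing = 32 // bits                       # 8 zero-points per int32 column
head_size = 128                            # illustrative value
kv_raw_cols = 2 * head_size                # 256 raw key/value columns
kv_packed_cols = kv_raw_cols * bits // 32  # 32 packed int32 columns
assert kv_raw_cols % packing == 0          # mirrors `assert 2 * head_size % (32 // 4) == 0`
print(kv_raw_cols, kv_packed_cols)         # 256 32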
@@ -92,7 +167,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
 
     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -105,7 +182,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
@@ -3,7 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from transformers import AutoConfig
-from transformers.models.llama import LlamaTokenizer
+from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
-        tokenizer = LlamaTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
 
         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
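Aside (not part of the diff): the three new files that follow add the Triton autotuner, the packed `QuantLinear` kernel, and the offline GPTQ quantization script. The mathematical core of the last one is the asymmetric scheme in `Quantizer.find_params`/`_quantize`; a hedged, single-row numeric illustration (the weights are made up, and the real code works per output channel):

# scale = (xmax - xmin) / maxq, zero = round(-xmin / scale),
# q = clamp(round(w / scale) + zero, 0, maxq), dequant = scale * (q - zero)
import torch

bits = 4
maxq = 2**bits - 1                                   # 15
w = torch.tensor([-0.6, -0.1, 0.0, 0.4, 0.9])
xmin, xmax = w.min().clamp(max=0.0), w.max().clamp(min=0.0)
scale = (xmax - xmin) / maxq                         # 0.1
zero = torch.round(-xmin / scale)                    # 6
q = torch.clamp(torch.round(w / scale) + zero, 0, maxq)
dequant = scale * (q - zero)
print(q.tolist())        # [0.0, 5.0, 6.0, 10.0, 15.0]
print(dequant.tolist())  # approximately [-0.6, -0.1, 0.0, 0.4, 0.9]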
@@ -0,0 +1,261 @@
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""

import builtins
import math
import time
from typing import Dict

import triton


class Autotuner(triton.KernelInterface):
    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        prune_configs_by: Dict = None,
        nearest_power_of_two: bool = False,
    ):
        """
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predicate running time with different configs, returns running time
            'top_k': number of configs to bench
            'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
            'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
        """
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.nearest_power_of_two = nearest_power_of_two
        self.cache = {}
        # hook to reset all required tensor to zeros before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()

            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = (
                prune_configs_by["perf_model"],
                prune_configs_by["top_k"],
            )
            if "early_config_prune" in prune_configs_by:
                early_config_prune = prune_configs_by["early_config_prune"]
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **current,
            )

        try:
            # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
            return triton.testing.do_bench(
                kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40
            )
        except triton.compiler.OutOfResources:
            return (float("inf"), float("inf"), float("inf"))

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple(args[i] for i in self.key_idx)

            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
            # In my testing this gives decent results, and greatly reduces the amount of tuning required
            if self.nearest_power_of_two:
                key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])

            if key not in self.cache:
                # prune configs
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {
                    config: self._bench(*args, config=config, **kwargs)
                    for config in pruned_configs
                }
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(
            *args,
            num_warps=config.num_warps,
            num_stages=config.num_stages,
            **kwargs,
            **config.kwargs,
        )

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.kwargs,
                        num_stages=config.num_stages,
                        num_warps=config.num_warps,
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
                    :top_k
                ]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None


def autotune(
    configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.
    .. highlight:: python
    .. code-block:: python
        @triton.autotune(configs=[
            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size'] # the two above configs will be evaluated anytime
                         # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']
    :note: When all the configurations are evaluated, the kernel will run multiple time.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           reset the value of the provided tensor to `zero` before running any configuration.
    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """

    def decorator(fn):
        return Autotuner(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            prune_configs_by,
            nearest_power_of_two,
        )

    return decorator


def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
    """
    m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
    n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
    k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
        block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
        block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
        group_size_m = config.kwargs["GROUP_SIZE_M"]

        if (
            block_size_m,
            block_size_n,
            block_size_k,
            group_size_m,
            config.num_stages,
            config.num_warps,
        ) in used:
            continue

        used.add(
            (
                block_size_m,
                block_size_n,
                block_size_k,
                group_size_m,
                config.num_stages,
                config.num_warps,
            )
        )
        yield triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "GROUP_SIZE_M": group_size_m,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
        )
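Aside (not part of the diff): the main behavioral difference from Triton's stock autotuner is the `nearest_power_of_two` option used in `Autotuner.run`, which snaps the cache key to the nearest power of two so that nearby problem sizes reuse one tuned config instead of re-benchmarking. A hedged illustration of that rounding (the sample sizes are arbitrary):

# Cache-key rounding used when nearest_power_of_two=True.
import math

def round_key(x: int) -> int:
    return 2 ** int(math.log2(x) + 0.5)

print([round_key(m) for m in (1, 3, 24, 100, 1500)])  # [1, 4, 32, 128, 2048]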
@@ -0,0 +1,359 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 256,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 32,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=4,
            ),
        ],
        key=["M", "N", "K"],
        nearest_power_of_two=True,
        prune_configs_by={
            "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
            "perf_model": None,
            "top_k": None,
        },
    )
    @triton.jit
    def matmul_248_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        scales_ptr,
        zeros_ptr,
        g_ptr,
        M,
        N,
        K,
        bits,
        maxq,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_scales,
        stride_zeros,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        """
        Compute the matrix multiplication C = A x B.
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infearure_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (
            offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
        )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = offs_am[:, None] < M
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + (
            (offs_k[:, None] // infearure_per_bits) * stride_bk
            + offs_bn[None, :] * stride_bn
        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(
                scales_ptrs + g_idx[:, None] * stride_scales
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(
                zeros_ptrs + g_idx[:, None] * stride_zeros
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = zeros + 1

            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)

except:
    print("triton not installed.")


def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty(
            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
        )
        grid = lambda META: (
            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
        )
        matmul_248_kernel[grid](
            input,
            qweight,
            output,
            scales,
            qzeros,
            g_idx,
            input.shape[0],
            qweight.shape[1],
            input.shape[1],
            bits,
            maxq,
            input.stride(0),
            input.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            output.stride(0),
            output.stride(1),
            scales.stride(0),
            qzeros.stride(0),
        )
        return output


class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        return output


class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize

        self.outfeatures = qweight.shape[1]
        self.infeatures = qweight.shape[0] * 32 // 4

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
            dtype=torch.int32,
        )
        scales = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
        )
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        if bias:
            bias = torch.zeros((outfeatures), dtype=torch.float16)
        else:
            bias = None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)

    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        out = QuantLinearFunction.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.bits,
            self.maxq,
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
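Aside (not part of the diff): as a quick orientation for the file above, `QuantLinear.new` allocates empty packed buffers, `pack` fills them from an fp16 `nn.Linear` plus the scales/zeros produced by the quantizer, and `forward` dispatches to the Triton kernel. A minimal, hedged sketch of driving the layer directly (shapes are illustrative; it needs a CUDA device with triton installed):

# Hedged usage sketch, not part of the PR.
import torch
from text_generation_server.utils.gptq.quant_linear import QuantLinear

layer = QuantLinear.new(
    bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=False
).cuda()
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = layer(x)  # (2, 4096) float16; all zeros here because the packed buffers are still empty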
@ -0,0 +1,866 @@
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from texttable import Texttable
|
||||||
|
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
||||||
|
import transformers
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
DEV = torch.device("cuda:0")
|
||||||
|
|
||||||
|
|
||||||
|
class Quantizer(nn.Module):
|
||||||
|
def __init__(self, shape=1):
|
||||||
|
super(Quantizer, self).__init__()
|
||||||
|
self.register_buffer("maxq", torch.tensor(0))
|
||||||
|
self.register_buffer("scale", torch.zeros(shape))
|
||||||
|
self.register_buffer("zero", torch.zeros(shape))
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
bits,
|
||||||
|
perchannel=False,
|
||||||
|
sym=True,
|
||||||
|
mse=False,
|
||||||
|
norm=2.4,
|
||||||
|
grid=100,
|
||||||
|
maxshrink=0.8,
|
||||||
|
trits=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.maxq = torch.tensor(2**bits - 1)
|
||||||
|
self.perchannel = perchannel
|
||||||
|
self.sym = sym
|
||||||
|
self.mse = mse
|
||||||
|
self.norm = norm
|
||||||
|
self.grid = grid
|
||||||
|
self.maxshrink = maxshrink
|
||||||
|
if trits:
|
||||||
|
self.maxq = torch.tensor(-1)
|
||||||
|
self.scale = torch.zeros_like(self.scale)
|
||||||
|
|
||||||
|
def _quantize(self, x, scale, zero, maxq):
|
||||||
|
if maxq < 0:
|
||||||
|
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
|
||||||
|
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
||||||
|
return scale * (q - zero)
|
||||||
|
|
||||||
|
def find_params(self, x, weight=False):
|
||||||
|
dev = x.device
|
||||||
|
self.maxq = self.maxq.to(dev)
|
||||||
|
|
||||||
|
shape = x.shape
|
||||||
|
if self.perchannel:
|
||||||
|
if weight:
|
||||||
|
x = x.flatten(1)
|
||||||
|
else:
|
||||||
|
if len(shape) == 4:
|
||||||
|
x = x.permute([1, 0, 2, 3])
|
||||||
|
x = x.flatten(1)
|
||||||
|
if len(shape) == 3:
|
||||||
|
x = x.reshape((-1, shape[-1])).t()
|
||||||
|
if len(shape) == 2:
|
||||||
|
x = x.t()
|
||||||
|
else:
|
||||||
|
x = x.flatten().unsqueeze(0)
|
||||||
|
|
||||||
|
tmp = torch.zeros(x.shape[0], device=dev)
|
||||||
|
xmin = torch.minimum(x.min(1)[0], tmp)
|
||||||
|
xmax = torch.maximum(x.max(1)[0], tmp)
|
||||||
|
|
||||||
|
if self.sym:
|
||||||
|
xmax = torch.maximum(torch.abs(xmin), xmax)
|
||||||
|
tmp = xmin < 0
|
||||||
|
if torch.any(tmp):
|
||||||
|
xmin[tmp] = -xmax[tmp]
|
||||||
|
tmp = (xmin == 0) & (xmax == 0)
|
||||||
|
xmin[tmp] = -1
|
||||||
|
xmax[tmp] = +1
|
||||||
|
|
||||||
|
if self.maxq < 0:
|
||||||
|
self.scale = xmax
|
||||||
|
self.zero = xmin
|
||||||
|
else:
|
||||||
|
self.scale = (xmax - xmin) / self.maxq
|
||||||
|
if self.sym:
|
||||||
|
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
|
||||||
|
else:
|
||||||
|
self.zero = torch.round(-xmin / self.scale)
|
||||||
|
|
||||||
|
if self.mse:
|
||||||
|
best = torch.full([x.shape[0]], float("inf"), device=dev)
|
||||||
|
for i in range(int(self.maxshrink * self.grid)):
|
||||||
|
p = 1 - i / self.grid
|
||||||
|
xmin1 = p * xmin
|
||||||
|
xmax1 = p * xmax
|
||||||
|
scale1 = (xmax1 - xmin1) / self.maxq
|
||||||
|
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
|
||||||
|
q = self._quantize(
|
||||||
|
x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
|
||||||
|
)
|
||||||
|
q -= x
|
||||||
|
q.abs_()
|
||||||
|
q.pow_(self.norm)
|
||||||
|
err = torch.sum(q, 1)
|
||||||
|
tmp = err < best
|
||||||
|
if torch.any(tmp):
|
||||||
|
best[tmp] = err[tmp]
|
||||||
|
self.scale[tmp] = scale1[tmp]
|
||||||
|
self.zero[tmp] = zero1[tmp]
|
||||||
|
if not self.perchannel:
|
||||||
|
if weight:
|
||||||
|
tmp = shape[0]
|
||||||
|
else:
|
||||||
|
tmp = shape[1] if len(shape) != 3 else shape[2]
|
||||||
|
self.scale = self.scale.repeat(tmp)
|
||||||
|
self.zero = self.zero.repeat(tmp)
|
||||||
|
|
||||||
|
if weight:
|
||||||
|
shape = [-1] + [1] * (len(shape) - 1)
|
||||||
|
self.scale = self.scale.reshape(shape)
|
||||||
|
self.zero = self.zero.reshape(shape)
|
||||||
|
return
|
||||||
|
if len(shape) == 4:
|
||||||
|
self.scale = self.scale.reshape((1, -1, 1, 1))
|
||||||
|
self.zero = self.zero.reshape((1, -1, 1, 1))
|
||||||
|
if len(shape) == 3:
|
||||||
|
self.scale = self.scale.reshape((1, 1, -1))
|
||||||
|
self.zero = self.zero.reshape((1, 1, -1))
|
||||||
|
if len(shape) == 2:
|
||||||
|
self.scale = self.scale.unsqueeze(0)
|
||||||
|
self.zero = self.zero.unsqueeze(0)
|
||||||
|
|
||||||
|
def quantize(self, x):
|
||||||
|
if self.ready():
|
||||||
|
return self._quantize(x, self.scale, self.zero, self.maxq)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def enabled(self):
|
||||||
|
return self.maxq > 0
|
||||||
|
|
||||||
|
def ready(self):
|
||||||
|
return torch.all(self.scale != 0)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQ:
|
||||||
|
def __init__(self, layer, observe=False):
|
||||||
|
self.layer = layer
|
||||||
|
self.dev = self.layer.weight.device
|
||||||
|
W = layer.weight.data.clone()
|
||||||
|
if isinstance(self.layer, nn.Conv2d):
|
||||||
|
W = W.flatten(1)
|
||||||
|
if isinstance(self.layer, transformers.Conv1D):
|
||||||
|
W = W.t()
|
||||||
|
self.rows = W.shape[0]
|
||||||
|
self.columns = W.shape[1]
|
||||||
|
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
|
||||||
|
self.nsamples = 0
|
||||||
|
self.quantizer = Quantizer()
|
||||||
|
self.observe = observe
|
||||||
|
|
||||||
|
def add_batch(self, inp, out):
|
||||||
|
# Hessian H = 2 X XT + λ I
|
||||||
|
if self.observe:
|
||||||
|
self.inp1 = inp
|
||||||
|
self.out1 = out
|
||||||
|
else:
|
||||||
|
self.inp1 = None
|
||||||
|
self.out1 = None
|
||||||
|
|
||||||
|
if len(inp.shape) == 2:
|
||||||
|
inp = inp.unsqueeze(0)
|
||||||
|
tmp = inp.shape[0]
|
||||||
|
if isinstance(self.layer, nn.Linear) or isinstance(
|
||||||
|
self.layer, transformers.Conv1D
|
||||||
|
):
|
||||||
|
if len(inp.shape) == 3:
|
||||||
|
inp = inp.reshape((-1, inp.shape[-1]))
|
||||||
|
inp = inp.t()
|
||||||
|
if isinstance(self.layer, nn.Conv2d):
|
||||||
|
unfold = nn.Unfold(
|
||||||
|
self.layer.kernel_size,
|
||||||
|
dilation=self.layer.dilation,
|
||||||
|
padding=self.layer.padding,
|
||||||
|
stride=self.layer.stride,
|
||||||
|
)
|
||||||
|
inp = unfold(inp)
|
||||||
|
inp = inp.permute([1, 0, 2])
|
||||||
|
inp = inp.flatten(1)
|
||||||
|
self.H *= self.nsamples / (self.nsamples + tmp)
|
||||||
|
self.nsamples += tmp
|
||||||
|
# inp = inp.float()
|
||||||
|
inp = math.sqrt(2 / self.nsamples) * inp.float()
|
||||||
|
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
|
||||||
|
self.H += inp.matmul(inp.t())
|
||||||
|
|
||||||
|
def print_loss(self, name, q_weight, weight_error, timecost):
|
||||||
|
table = Texttable()
|
||||||
|
length = 28
|
||||||
|
name = (
|
||||||
|
(name + " " * (length - len(name)))
|
||||||
|
if len(name) <= length
|
||||||
|
else name[:length]
|
||||||
|
)
|
||||||
|
|
||||||
|
table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
|
||||||
|
|
||||||
|
# assign weight
|
||||||
|
self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
|
||||||
|
self.layer.weight.data.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.inp1 is not None:
|
||||||
|
# quantize input to int8
|
||||||
|
quantizer = Quantizer()
|
||||||
|
quantizer.configure(8, perchannel=False, sym=True, mse=False)
|
||||||
|
quantizer.find_params(self.inp1)
|
||||||
|
q_in = quantizer.quantize(self.inp1).type(torch.float16)
|
||||||
|
q_out = self.layer(q_in)
|
||||||
|
|
||||||
|
# get kinds of SNR
|
||||||
|
q_SNR = torch_snr_error(q_out, self.out1).item()
|
||||||
|
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
|
||||||
|
else:
|
||||||
|
q_SNR = "-"
|
||||||
|
fp_SNR = "-"
|
||||||
|
|
||||||
|
table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
|
||||||
|
print(table.draw().split("\n")[-2])
|
||||||
|
|
||||||
|
def fasterquant(
|
||||||
|
self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
|
||||||
|
):
|
||||||
|
self.layer.to(self.dev)
|
||||||
|
|
||||||
|
W = self.layer.weight.data.clone()
|
||||||
|
if isinstance(self.layer, nn.Conv2d):
|
||||||
|
W = W.flatten(1)
|
||||||
|
if isinstance(self.layer, transformers.Conv1D):
|
||||||
|
W = W.t()
|
||||||
|
W = W.float()
|
||||||
|
|
||||||
|
tick = time.time()
|
||||||
|
|
||||||
|
if not self.quantizer.ready():
|
||||||
|
self.quantizer.find_params(W, weight=True)
|
||||||
|
|
||||||
|
H = self.H
|
||||||
|
if not self.observe:
|
||||||
|
del self.H
|
||||||
|
dead = torch.diag(H) == 0
|
||||||
|
H[dead, dead] = 1
|
||||||
|
W[:, dead] = 0
|
||||||
|
|
||||||
|
if act_order:
|
||||||
|
perm = torch.argsort(torch.diag(H), descending=True)
|
||||||
|
W = W[:, perm]
|
||||||
|
H = H[perm][:, perm]
|
||||||
|
|
||||||
|
Losses = torch.zeros_like(W)
|
||||||
|
Q = torch.zeros_like(W)
|
||||||
|
|
||||||
|
damp = percdamp * torch.mean(torch.diag(H))
|
||||||
|
diag = torch.arange(self.columns, device=self.dev)
|
||||||
|
H[diag, diag] += damp
|
||||||
|
H = torch.linalg.cholesky(H)
|
||||||
|
H = torch.cholesky_inverse(H)
|
||||||
|
try:
|
||||||
|
H = torch.linalg.cholesky(H, upper=True)
|
||||||
|
except Exception:
|
||||||
|
# Addition because Falcon fails on h_to_4h
|
||||||
|
H = torch.linalg.cholesky(
|
||||||
|
H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
|
||||||
|
)
|
||||||
|
Hinv = H
|
||||||
|
|
||||||
|
g_idx = []
|
||||||
|
scale = []
|
||||||
|
zero = []
|
||||||
|
now_idx = 1
|
||||||
|
|
||||||
|
for i1 in range(0, self.columns, blocksize):
|
||||||
|
i2 = min(i1 + blocksize, self.columns)
|
||||||
|
count = i2 - i1
|
||||||
|
|
||||||
|
W1 = W[:, i1:i2].clone()
|
||||||
|
Q1 = torch.zeros_like(W1)
|
||||||
|
Err1 = torch.zeros_like(W1)
|
||||||
|
Losses1 = torch.zeros_like(W1)
|
||||||
|
Hinv1 = Hinv[i1:i2, i1:i2]
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
w = W1[:, i]
|
||||||
|
d = Hinv1[i, i]
|
||||||
|
|
||||||
|
if groupsize != -1:
|
||||||
|
if (i1 + i) % groupsize == 0:
|
||||||
|
self.quantizer.find_params(
|
||||||
|
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if ((i1 + i) // groupsize) - now_idx == -1:
|
||||||
|
scale.append(self.quantizer.scale)
|
||||||
|
zero.append(self.quantizer.zero)
|
||||||
|
now_idx += 1
|
||||||
|
|
||||||
|
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
|
||||||
|
Q1[:, i] = q
|
||||||
|
Losses1[:, i] = (w - q) ** 2 / d**2
|
||||||
|
|
||||||
|
err1 = (w - q) / d
|
||||||
|
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
|
||||||
|
Err1[:, i] = err1
|
||||||
|
|
||||||
|
Q[:, i1:i2] = Q1
|
||||||
|
Losses[:, i1:i2] = Losses1 / 2
|
||||||
|
|
||||||
|
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
error = torch.sum(Losses).item()
|
||||||
|
|
||||||
|
groupsize = groupsize if groupsize != -1 else self.columns
|
||||||
|
g_idx = [i // groupsize for i in range(self.columns)]
|
||||||
|
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
|
||||||
|
if act_order:
|
||||||
|
invperm = torch.argsort(perm)
|
||||||
|
Q = Q[:, invperm]
|
||||||
|
g_idx = g_idx[invperm]
|
||||||
|
|
||||||
|
if isinstance(self.layer, transformers.Conv1D):
|
||||||
|
Q = Q.t()
|
||||||
|
|
||||||
|
self.print_loss(
|
||||||
|
name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
|
||||||
|
)
|
||||||
|
|
||||||
|
if scale == []:
|
||||||
|
scale.append(self.quantizer.scale)
|
||||||
|
zero.append(self.quantizer.zero)
|
||||||
|
scale = torch.cat(scale, dim=1)
|
||||||
|
zero = torch.cat(zero, dim=1)
|
||||||
|
return scale, zero, g_idx, error
|
||||||
|
|
||||||
|
def free(self):
|
||||||
|
self.inp1 = None
|
||||||
|
self.out1 = None
|
||||||
|
self.H = None
|
||||||
|
self.Losses = None
|
||||||
|
self.Trace = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def get_wikitext2(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_ptb(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_c4(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Sample documents until one is long enough for a full calibration sequence
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc

def get_ptb_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_c4_new(nsamples, seed, seqlen, model_id):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc

def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
    if "wikitext2" in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id)
    if "ptb" in name:
        if "new" in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id)
        return get_ptb(nsamples, seed, seqlen, model_id)
    if "c4" in name:
        if "new" in name:
            return get_c4_new(nsamples, seed, seqlen, model_id)
        return get_c4(nsamples, seed, seqlen, model_id)

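For orientation, here is a minimal sketch of how a calibration loader produced above might be consumed. The model id is a placeholder, and the argument values simply mirror the defaults used by `quantize()` further down.

    # Sketch only: the model id is a placeholder.
    dataloader, testloader = get_loaders(
        "wikitext2", nsamples=128, seed=0, seqlen=2048, model_id="some/model-id"
    )
    inp, tar = dataloader[0]
    # `inp` is one (1, seqlen) slice of token ids; `tar` is the same slice with
    # every position except the last masked to -100.
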
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    # Skip the final lm_head linear.
    # isinstance is needed because Falcon's linear layers inherit from Linear.
    if isinstance(module, layers) and "lm_head" not in name:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(
            find_layers(
                child, layers=layers, name=name + "." + name1 if name != "" else name1
            )
        )
    return res

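As an illustration, calling `find_layers` on a single decoder layer (as `sequential()` does below) yields a flat mapping from dotted submodule names to the linear modules that will be quantized. The example names here are hypothetical and depend on the architecture.

    full = find_layers(layer)
    # e.g. {"self_attn.query_key_value": Linear(...), "mlp.dense_h_to_4h": Linear(...)}
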
@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers
        prefix = "model.layers"
    except Exception:
        layers = model.transformer.h
        prefix = "transformer.h"

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {"i": 0}
    extra = {}

    class Catcher(nn.Module):
        """Captures the inputs of the first decoder layer, then aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            extra.update(kwargs.copy())
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0])
        except ValueError:
            pass
    layers[0] = layers[0].module

    # layers[0] = layers[0].cpu()
    # model.model.embed_tokens = model.model.embed_tokens.cpu()
    # model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)

    extra = {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(f"Quantizing layer {i+1}/{len(layers)}..")
        print("+------------------+--------------+------------+-----------+-------+")
        print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
        print("+==================+==============+============+===========+=======+")

        from accelerate.hooks import remove_hook_from_submodules

        layer = layers[i].to(dev)
        remove_hook_from_submodules(layer)
        full = find_layers(layer)
        sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(
                    bits, perchannel=True, sym=sym, mse=False
                )

            def add_batch(name):
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            # Collect input statistics for each linear via forward hooks
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(
                    percdamp=percdamp,
                    groupsize=groupsize,
                    act_order=act_order,
                    name=name,
                )
                quantizers[f"{prefix}.{i}.{name}"] = (
                    gptq[name].quantizer.cpu(),
                    scale.cpu(),
                    zero.cpu(),
                    g_idx.cpu(),
                    bits,
                    groupsize,
                )

                gptq[name].free()

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print("+------------------+--------------+------------+-----------+-------+")
        print("\n")

    model.config.use_cache = use_cache

    return quantizers

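The `quantizers` mapping returned above is what `pack()` below consumes. A sketch of its shape follows; the key shown is illustrative, and real keys follow the `prefix` detected at the top of `sequential()`.

    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
    # quantizers["model.layers.0.self_attn.q_proj"] ==
    #     (quantizer, scale, zero, g_idx, bits, groupsize)
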
def make_quant_linear(module, names, bits, groupsize, name=""):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + "." + attr if name != "" else attr
        if name1 in names:
            delattr(module, attr)
            setattr(
                module,
                attr,
                QuantLinear.new(
                    bits,
                    groupsize,
                    tmp.in_features,
                    tmp.out_features,
                    tmp.bias is not None,
                ),
            )
    for name1, child in module.named_children():
        make_quant_linear(
            child, names, bits, groupsize, name + "." + name1 if name != "" else name1
        )

# TODO: perform packing on GPU
def pack(model, quantizers, bits, groupsize):
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant_linear(model, quantizers, bits, groupsize)
    qlayers = find_layers(model, (QuantLinear,))
    print("Packing ...")
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print("Done.")
    return model

def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
    percdamp: float,
    act_order: bool,
):
    print("loading model")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="balanced_low_0",
        trust_remote_code=trust_remote_code,
    )
    print("LOADED model")
    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    seed = None

    dataloader, testloader = get_loaders(
        dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
    )

    tick = time.time()
    quantizers = sequential(
        model,
        dataloader,
        DEV,
        nsamples,
        bits,
        groupsize,
        percdamp=percdamp,
        act_order=act_order,
    )
    print(time.time() - tick)

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])

    max_shard_size = "10GB"
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
    )
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
        save_index_file = "model.safetensors.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
            f"index located at {save_index_file}."
        )
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    config.save_pretrained(output_dir)
    logger.info("Saved config")
    logger.info("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=trust_remote_code
    )
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved tokenizer")

    if upload_to_model_id:
        api = HfApi()

        api.upload_folder(
            folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
        )

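Putting the script together: load the fp16 model, calibrate layer by layer with `sequential()`, `pack()` the results into `QuantLinear` modules, then save safetensors shards along with the `gptq_bits`/`gptq_groupsize` tensors that the serving-side loader reads back. A hedged sketch of invoking it; the model id, output path, and the 4-bit/128-groupsize choice are placeholders, and the CLI form is the one hinted at by the loader error message further down.

    # Python entry point, using the signature defined above.
    quantize(
        model_id="tiiuae/falcon-7b",
        bits=4,
        groupsize=128,
        output_dir="/data/falcon-7b-gptq",
        trust_remote_code=True,
        upload_to_model_id=None,
        percdamp=0.01,
        act_order=False,
    )
    # Roughly equivalent CLI, as referenced by the loader error message:
    #   text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID
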
@@ -15,6 +15,8 @@ except ImportError:
 
 from accelerate import init_empty_weights
 
+from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
 
 # Monkey patching
 @classmethod
@@ -129,7 +131,22 @@ def get_linear(weight, bias, quantize):
         if bias is not None:
             linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
-        raise NotImplementedError("Soon")
+        try:
+            qweight, qzeros, scales, g_idx, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+            )
+
+        linear = QuantLinear(
+            qweight,
+            qzeros,
+            scales,
+            g_idx,
+            bias,
+            bits,
+            groupsize,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -152,8 +169,14 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+
+        # GPTQ doesn't quantize heads (nor embeddings)
+        if config.quantize == "gptq":
+            quantize = None
+        else:
+            quantize = config.quantize
         return TensorParallelHead(
-            get_linear(weight, bias=None, quantize=config.quantize),
+            get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
         )
@@ -196,24 +219,21 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        if bias:
-            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
-        else:
-            bias = None
-        return cls(get_linear(weight, bias, config.quantize))
+        return cls.load_multi(config, [prefix], weights, bias, dim=0)
 
     @classmethod
     def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        weight = torch.cat(w, dim=dim)
+        weight = weights.get_multi_weights_col(
+            prefixes, quantize=config.quantize, dim=dim
+        )
 
         if bias:
             b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
-            bias = torch.cat(b, dim=0)
+            bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        return cls(get_linear(weight, bias, config.quantize))
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
 
 
 class TensorParallelRowLinear(SuperLayer):
@@ -223,7 +243,8 @@ class TensorParallelRowLinear(SuperLayer):
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
+import torch
 
 
 class Weights:
@@ -54,7 +55,10 @@ class Weights:
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
-        tensor = tensor.to(dtype=self.dtype)
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype not in [torch.int32, torch.int64]:
+            tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -80,6 +84,49 @@ class Weights:
             tensor = slice_[:, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
-        tensor = tensor.to(dtype=self.dtype)
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        if quantize == "gptq":
+            try:
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, "
+                    "or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+
+            bits = self.get_tensor("gptq_bits").item()
+            groupsize = self.get_tensor("gptq_groupsize").item()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+            weight = torch.cat(w, dim=dim)
+        return weight
+
+    def get_multi_weights_row(self, prefix: str, quantize: str):
+        if quantize == "gptq":
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, "
+                    "or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            bits = self.get_tensor("gptq_bits").item()
+            groupsize = self.get_tensor("gptq_groupsize").item()
+
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        else:
+            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+        return weight
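To make the loader path concrete: for `quantize == "gptq"` both helpers return the six-element tuple that `get_linear()` above unpacks into a `QuantLinear`. A sketch of that wiring follows; the prefix is illustrative, and `QuantLinear` is assumed to take the arguments positionally, as in the `get_linear` hunk.

    # Hypothetical prefix for one attention output projection.
    weight = weights.get_multi_weights_row(
        "model.layers.0.self_attn.o_proj", quantize="gptq"
    )
    qweight, qzeros, scales, g_idx, bits, groupsize = weight
    linear = QuantLinear(qweight, qzeros, scales, g_idx, None, bits, groupsize)  # bias=None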