feat(server): Using `quantize_config.json` instead of GPTQ_BITS env variables. (#671)

- The current PR is not great because we're sidestepping
  `Weights.__init__`, but `Weights` shouldn't require anything related
  to the config or the model_id, as it aims to be a simple wrapper
  over multi-file loading.
- The ideal solution would be to use something like a Rust enum
  ```
  enum Quantize {
    Bitsandbytes(Bitsandbytes),
    GPTQ { bits: usize, groupsize: usize },
  }
  ```
  and pass that around during load. Unfortunately we don't
  have access to this, so for now, side-stepping seems easier.

- Re-enabling groupsize<0 with exllama (confirmed it works.)

Helps #601 

As a next step, we should make sure our quantization script uses that
format and make it the standard.


# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2023-07-25 12:00:27 +01:00 committed by GitHub
parent 37df6df38e
commit a0d55358d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 53 additions and 24 deletions

View File

@ -76,6 +76,8 @@ class BLOOMSharded(CausalLM):
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames, device=device, dtype=dtype, process_group=self.process_group
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = BloomForCausalLM(config, weights) model = BloomForCausalLM(config, weights)

View File

@ -76,7 +76,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_qparams() bits, groupsize = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@ -130,17 +130,17 @@ class OPTAttention(nn.Module):
process_group=None, process_group=None,
): ):
super().__init__() super().__init__()
embed_dim = config.embed_dim hidden_size = config.hidden_size
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
self.embed_dim = embed_dim self.hidden_size = hidden_size
self.num_heads = num_heads self.num_heads = num_heads
self.dropout = config.dropout self.dropout = config.dropout
self.head_dim = embed_dim // num_heads self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.embed_dim: if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError( raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})." f" and `num_heads`: {num_heads})."
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
@ -153,7 +153,7 @@ class OPTAttention(nn.Module):
f"and `num_shards`: {weights.process_group.size()}" f"and `num_shards`: {weights.process_group.size()}"
) )
self.num_heads = self.num_heads // process_group.size() self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size() self.hidden_size = self.hidden_size // process_group.size()
self.q_proj = TensorParallelColumnLinear.load( self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
@ -300,9 +300,9 @@ class OPTAttention(nn.Module):
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2) attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism. # partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
@ -313,7 +313,7 @@ class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights): def __init__(self, layer_id: int, config: OPTConfig, weights):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.embed_dim = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}" prefix = f"model.decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
@ -352,7 +352,7 @@ class OPTDecoderLayer(nn.Module):
]: ]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size

View File

@ -55,13 +55,15 @@ class FlashLlama(FlashCausalLM):
config = LlamaConfig.from_pretrained( config = LlamaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
config.quantize = quantize
model = FlashLlamaForCausalLM(config, weights) model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -52,6 +52,8 @@ class FlashNeoXSharded(FlashCausalLM):
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames, device=device, dtype=dtype, process_group=self.process_group
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashGPTNeoXForCausalLM(config, weights) model = FlashGPTNeoXForCausalLM(config, weights)

View File

@ -58,6 +58,9 @@ class FlashRWSharded(FlashCausalLM):
) )
config.quantize = quantize config.quantize = quantize
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashRWForCausalLM(config, weights) model = FlashRWForCausalLM(config, weights)

View File

@ -4,7 +4,10 @@ import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM, FlashSantacoderForCausalLM,
@ -59,6 +62,8 @@ class FlashSantacoderSharded(FlashCausalLM):
process_group=self.process_group, process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]}, aliases={"transformer.wte.weight": ["lm_head.weight"]},
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashSantacoderForCausalLM(config, weights) model = FlashSantacoderForCausalLM(config, weights)

View File

@ -191,6 +191,8 @@ class GalacticaSharded(CausalLM):
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames, device=device, dtype=dtype, process_group=self.process_group
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = OPTForCausalLM(config, weights) model = OPTForCausalLM(config, weights)

View File

@ -56,6 +56,8 @@ class GPTNeoxSharded(CausalLM):
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames, device=device, dtype=dtype, process_group=self.process_group
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = GPTNeoxForCausalLM(config, weights) model = GPTNeoxForCausalLM(config, weights)

View File

@ -78,6 +78,8 @@ class MPTSharded(CausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
config.quantize = quantize config.quantize = quantize
model = MPTForCausalLM(config, weights) model = MPTForCausalLM(config, weights)

View File

@ -54,6 +54,8 @@ class OPTSharded(CausalLM):
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames, device=device, dtype=dtype, process_group=self.process_group
) )
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = OPTForCausalLM(config, weights) model = OPTForCausalLM(config, weights)

View File

@ -3,6 +3,8 @@ from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError from safetensors import safe_open, SafetensorError
import torch import torch
from loguru import logger from loguru import logger
from huggingface_hub import hf_hub_download
import json
class Weights: class Weights:
@ -128,7 +130,7 @@ class Weights:
torch.testing.assert_close(w2, w[0]) torch.testing.assert_close(w2, w[0])
g_idx = w[0] g_idx = w[0]
bits, groupsize = self._get_gptq_qparams() bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -138,7 +140,7 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str): def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq": if quantize == "gptq":
use_exllama = True use_exllama = True
bits, groupsize = self._get_gptq_qparams() bits, groupsize = self._get_gptq_params()
if bits != 4: if bits != 4:
use_exllama = False use_exllama = False
@ -185,11 +187,8 @@ class Weights:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0)
else: else:
raise RuntimeError( qzeros = self.get_tensor(f"{prefix}.qzeros")
"Using exllama GPTQ kernel with groupsize<1 is not supported" scales = self.get_tensor(f"{prefix}.scales")
)
# qzeros = self.get_tensor(f"{prefix}.qzeros")
# scales = self.get_tensor(f"{prefix}.scales")
# For tp > 1, at this point we know we do not use act-order # For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1: if self.process_group.size() == 1:
@ -208,17 +207,25 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight return weight
def _get_gptq_qparams(self) -> Tuple[int, int]: def _get_gptq_params(self) -> Tuple[int, int]:
try: try:
bits = self.get_tensor("gptq_bits").item() bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item() groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e: except (SafetensorError, RuntimeError) as e:
try: try:
import os bits = self.gptq_bits
groupsize = self.gptq_groupsize
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception: except Exception:
raise e raise e
return bits, groupsize return bits, groupsize
def _set_gptq_params(self, model_id):
try:
filename = hf_hub_download(model_id, filename="quantize_config.json")
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
except Exception:
pass