2024-07-20 11:02:04 -06:00
|
|
|
import torch
|
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
from abc import ABC, abstractmethod
|
2024-07-15 23:58:25 -06:00
|
|
|
from contextlib import contextmanager
|
2023-06-08 06:51:52 -06:00
|
|
|
from pathlib import Path
|
2024-07-20 11:02:04 -06:00
|
|
|
from typing import Dict, List, Optional, Union, Type
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantizer linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
from safetensors import safe_open
|
2024-07-20 11:02:04 -06:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantizer linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
2024-07-09 12:04:03 -06:00
|
|
|
|
|
|
|
|
|
|
|
class WeightsLoader(ABC):
    """
    Interface for higher-level weight loading.

    At a low level, every weight is stored in the Safetensors format. How
    those raw tensors must be interpreted can differ, however: they could,
    for instance, be packed, quantized weights. A loader is responsible for
    interpreting the raw tensors, sharding them in a manner compatible with
    the format, etc.
    """

    @abstractmethod
    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix without applying tensor
        parallelism.
        """

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. Use this method when several different weights
        are packed into one tensor, for instance query/key/value weights or
        a gate/up projection.

        `block_sizes` determines the proportions of the packed tensors. The
        columns are split in equally sized blocks when `block_sizes` is an
        `int`, or in blocks proportional to the given sizes otherwise. For
        instance, `[2, 1, 1]` divides an input with dimensionality `1024`
        into `[512, 256, 256]`.
        """

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply column-splitting for
        tensor parallelism.
        """
        # A single column-split weight is the one-element case of the
        # multi-prefix lookup, concatenated along dimension 0.
        single_prefix = [prefix]
        return weights.get_multi_weights_col(single_prefix, 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given
        dimension.
        """

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for
        tensor parallelism.
        """
|
|
|
|
|
|
|
|
|
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
instead relied on conditional in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers,
`get_linear` does not need to know how to handle quantizer linear
layers.
- All quantizer weights are strongly typed, we don't pass around
raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
2024-07-19 01:37:39 -06:00
|
|
|
class Weight(ABC):
    """Base type for unquantized, quantized and to-be-quantized weights."""

    @abstractmethod
    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer from this weight, using the given `bias`."""
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UnquantizedWeight(Weight):
    """Weight that wraps a plain Torch tensor."""

    # The wrapped tensor, handed unchanged to the linear layer.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Create a linear layer from the wrapped tensor and `bias`.

        Returns a `FastLinearROCm` when `SYSTEM` is `"rocm"` and a
        `FastLinear` otherwise.
        """
        # Kept as a function-level import, exactly as in the original code.
        from text_generation_server.layers.linear import FastLinear, FastLinearROCm

        linear_cls = FastLinearROCm if SYSTEM == "rocm" else FastLinear
        return linear_cls(self.weight, bias)
|
|
|
|
|
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
class DefaultWeightsLoader(WeightsLoader):
    """Weight loader that loads (unquantized) Torch tensors.

    The tensors are used as-is, with the exception of applying sharding
    and/or concatenation.
    """

    def __init__(self, weight_class: Type[UnquantizedWeight]):
        """Create a loader.

        Weights will be wrapped using the given `weight_class`. Normally this
        will be `UnquantizedWeight`, but a quantizer-specific class such as
        `Fp8Weight` can be used to quantize the weights during loading.
        """
        self.weight_class = weight_class

    def get_weights(self, weights: "Weights", prefix: str):
        """Return the tensor stored at `{prefix}.weight`.

        The tensor is returned exactly as `weights.get_tensor` provides it;
        unlike the other getters of this loader, it is not wrapped in
        `weight_class`.
        """
        return weights.get_tensor(f"{prefix}.weight")

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """Return the packed tensor at `{prefix}.weight`, sharded along
        dimension 0 according to `block_sizes` and wrapped in `weight_class`.
        """
        return self.weight_class(
            weights.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            ),
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """Return the concatenation of the `{prefix}.weight` tensors.

        Each tensor is sharded along dimension 0, the shards are concatenated
        along `dim`, and the result is wrapped in `weight_class`.
        """
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return self.weight_class(torch.cat(w, dim=dim))

    def get_weights_row(self, weights: "Weights", prefix: str):
        """Return the tensor at `{prefix}.weight`, sharded along dimension 1
        and wrapped in `weight_class`.
        """
        return self.weight_class(
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )
|
2023-06-08 06:51:52 -06:00
|
|
|
|
|
|
|
|
|
|
|
class Weights:
|
2023-06-30 11:09:59 -06:00
|
|
|
def __init__(
    self,
    filenames: List[Path],
    device,
    dtype,
    process_group,
    weights_loader: WeightsLoader,
    aliases: Optional[Dict[str, List[str]]] = None,
    prefix: Optional[str] = None,
):
    """Index the tensors of the given Safetensors files.

    Every file in `filenames` is opened once to record which file stores
    which tensor name. The remaining arguments are stored unchanged on the
    instance.

    Raises:
        RuntimeError: if the same tensor name occurs in more than one file.
    """
    # Tensor name -> file that stores it.
    tensor_to_file = {}
    for path in filenames:
        with safe_open(path, framework="pytorch") as handle:
            for key in handle.keys():
                if key in tensor_to_file:
                    raise RuntimeError(
                        f"Key {key} was found in multiple files: {path} and {tensor_to_file[key]}"
                    )
                tensor_to_file[key] = path

    self.routing = tensor_to_file
    # No aliases given means no alternative names are known.
    self.aliases = {} if aliases is None else aliases
    self.prefix = prefix
    self.device = device
    self.dtype = dtype
    self.process_group = process_group
    self.weights_loader = weights_loader
    # Handles of files opened for reading tensors, filled by `_get_handle`.
    self._handles = {}
|
|
|
|
|
|
|
|
def _get_handle(self, filename):
    """Return the Safetensors handle for `filename`, opening it on first use.

    Opened handles are kept in `self._handles` and reused on later calls.
    """
    open_handles = self._handles
    if filename in open_handles:
        return open_handles[filename]

    handle = safe_open(filename, framework="pytorch")
    open_handles[filename] = handle
    return handle
|
|
|
|
|
2023-06-23 04:40:46 -06:00
|
|
|
def get_filename(self, tensor_name: str) -> (str, str):
    """Locate the file that stores `tensor_name`.

    The name is tried as given and, when `self.prefix` is set, also as
    `{prefix}.{tensor_name}`. For each of these candidates the aliases
    registered in `self.aliases` are tried as well.

    Returns:
        A `(filename, name)` pair, where `name` is the name under which the
        tensor was actually found.

    Raises:
        RuntimeError: if neither a candidate nor any of its aliases is known.
    """
    candidates = [tensor_name]
    if self.prefix is not None:
        candidates.append(f"{self.prefix}.{tensor_name}")

    for candidate in candidates:
        # A direct hit wins over any alias of the same candidate.
        location = self.routing.get(candidate, None)
        if location is not None:
            return str(location), candidate

        for alias in self.aliases.get(candidate, []):
            location = self.routing.get(alias, None)
            if location is not None:
                return str(location), alias

    raise RuntimeError(f"weight {tensor_name} does not exist")
|
2023-06-08 06:51:52 -06:00
|
|
|
|
|
|
|
def _get_slice(self, tensor_name: str):
    """Return the slice object of the file handle for `tensor_name`.

    The name is first resolved through `get_filename`, so the prefix and
    aliases are taken into account.
    """
    path, resolved_name = self.get_filename(tensor_name)
    return self._get_handle(path).get_slice(resolved_name)
|
|
|
|
|
2024-10-16 01:54:50 -06:00
|
|
|
def has_tensor(self, tensor_name: str):
    """Return whether `tensor_name` can be resolved by `get_filename`.

    Any exception raised during the lookup is treated as "not present".
    """
    try:
        self.get_filename(tensor_name)
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
2023-06-08 06:51:52 -06:00
|
|
|
def get_shape(self, tensor_name: str):
    """Return the shape reported by the stored slice of `tensor_name`."""
    tensor_slice = self._get_slice(tensor_name)
    return tensor_slice.get_shape()
|
|
|
|
|
2024-10-24 08:36:18 -06:00
|
|
|
def get_tensor(
|
|
|
|
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
|
|
|
|
) -> torch.Tensor:
|
2023-06-23 04:40:46 -06:00
|
|
|
filename, tensor_name = self.get_filename(tensor_name)
|
2023-06-08 06:51:52 -06:00
|
|
|
f = self._get_handle(filename)
|
|
|
|
tensor = f.get_tensor(tensor_name)
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
# Special case for gptq which shouldn't convert
|
2024-05-28 03:51:31 -06:00
|
|
|
# u4 which are disguised as int32. Exl2 uses int16
|
2024-07-20 11:02:04 -06:00
|
|
|
# as well. FP8 uses torch.float8_e4m3fn
|
|
|
|
if (
|
|
|
|
tensor.dtype
|
|
|
|
not in [
|
|
|
|
torch.float8_e4m3fn,
|
|
|
|
torch.int16,
|
|
|
|
torch.int32,
|
|
|
|
torch.int64,
|
|
|
|
]
|
|
|
|
and to_dtype
|
|
|
|
):
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
tensor = tensor.to(dtype=self.dtype)
|
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, the changes of this time mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
|
|
|
if to_device:
|
|
|
|
tensor = tensor.to(device=self.device)
|
2023-06-08 06:51:52 -06:00
|
|
|
return tensor
|
|
|
|
|
2024-07-22 09:51:32 -06:00
|
|
|
def get_partial_sharded(
|
|
|
|
self, tensor_name: str, dim: int, to_device=True, to_dtype=True
|
|
|
|
):
|
2023-06-23 04:40:46 -06:00
|
|
|
filename, tensor_name = self.get_filename(tensor_name)
|
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, the changes of this time mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
|
|
|
f = self._get_handle(filename)
|
|
|
|
slice_ = f.get_slice(tensor_name)
|
2023-06-08 06:51:52 -06:00
|
|
|
world_size = self.process_group.size()
|
|
|
|
rank = self.process_group.rank()
|
|
|
|
|
|
|
|
size = slice_.get_shape()[dim]
|
2024-01-24 05:08:41 -07:00
|
|
|
block_size = (size + world_size - 1) // world_size
|
2023-06-08 06:51:52 -06:00
|
|
|
start = rank * block_size
|
|
|
|
stop = (rank + 1) * block_size
|
|
|
|
|
|
|
|
if dim == 0:
|
|
|
|
tensor = slice_[start:stop]
|
|
|
|
elif dim == 1:
|
|
|
|
tensor = slice_[:, start:stop]
|
|
|
|
else:
|
|
|
|
raise NotImplementedError("Let's make that generic when needed")
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
# Special case for gptq which shouldn't convert
|
2024-05-28 03:51:31 -06:00
|
|
|
# u4 which are disguised as int32. exl2 uses int16.
|
2024-07-20 11:02:04 -06:00
|
|
|
# FP8 uses torch.float8_e4m3fn.
|
|
|
|
if (
|
|
|
|
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
|
|
|
|
and to_dtype
|
|
|
|
):
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
tensor = tensor.to(dtype=self.dtype)
|
2024-07-22 09:51:32 -06:00
|
|
|
if to_device:
|
|
|
|
tensor = tensor.to(device=self.device)
|
2023-06-08 06:51:52 -06:00
|
|
|
return tensor
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
|
2024-07-22 09:51:32 -06:00
|
|
|
def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
    """
    Load a shard of a tensor, requiring an even split across ranks.

    The tensor name is resolved with `get_filename`, and the size of
    dimension `dim` must be divisible by the size of the process group.
    Once that check passes, the call is delegated to
    `get_partial_sharded` with the resolved tensor name; `to_device` and
    `to_dtype` are forwarded unchanged.

    Raises an `AssertionError` when the dimension cannot be split evenly.
    """
    filename, tensor_name = self.get_filename(tensor_name)
    handle = self._get_handle(filename)
    size = handle.get_slice(tensor_name).get_shape()[dim]
    world_size = self.process_group.size()
    # Reject uneven splits before doing any actual loading.
    assert (
        size % world_size == 0
    ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
    return self.get_partial_sharded(
        tensor_name, dim, to_device=to_device, to_dtype=to_dtype
    )
|
2023-07-12 08:43:31 -06:00
|
|
|
|
2024-06-20 01:56:04 -06:00
|
|
|
def get_packed_sharded(
    self,
    tensor_name: str,
    dim: int,
    block_sizes: Union[int, List[int]],
    to_dtype=True,
) -> torch.Tensor:
    """
    Get a shard from a tensor that packs multiple tensors.

    When a tensor packs multiple tensors (such as QKV or an up
    projection + gate projection), sharding with `get_sharded` is not
    safe since it would not split each of the packed tensors across
    shards.

    This method shards every packed block separately and concatenates
    the pieces that belong to the current rank along `dim`.

    `block_sizes` is either the number of equally sized blocks (`int`)
    or a list of block proportions. For instance `[2, 1, 1]` will divide
    an input with dimensionality `1024` in `[512, 256, 256]`. This is
    convenient for e.g. splitting QKV without knowing the storage
    details of quantized weights.

    The result is moved to `self.device`. It is cast to `self.dtype`
    only when `to_dtype` is true and the loaded dtype is not one of the
    dtypes that are kept as-is (see the comment in the body).
    """
    slice_ = self._get_slice(tensor_name)
    sizes = _blocks_to_block_sizes(
        total_size=slice_.get_shape()[dim], blocks=block_sizes
    )

    world_size = self.process_group.size()
    rank = self.process_group.rank()

    shards = []
    block_start = 0
    for block_size in sizes:
        assert (
            block_size % world_size == 0
        ), f"Prepacked tensor cannot be sharded across {world_size} shards"
        per_rank = block_size // world_size
        # Bounds of this rank's piece within the current packed block.
        lo = block_start + rank * per_rank
        hi = lo + per_rank
        if dim == 0:
            shards.append(slice_[lo:hi])
        elif dim == 1:
            shards.append(slice_[:, lo:hi])
        else:
            raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
        block_start += block_size

    tensor = torch.cat(shards, dim=dim).to(device=self.device)

    # Avoid casting quantizer dtypes.
    no_cast_dtypes = (
        torch.float8_e4m3fn,
        torch.int16,
        torch.int32,
        torch.int64,
    )
    if to_dtype and tensor.dtype not in no_cast_dtypes:
        tensor = tensor.to(dtype=self.dtype)

    return tensor
|
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, the changes this time are mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
|
|
|
|
2024-07-19 09:23:20 -06:00
|
|
|
def get_weights(self, prefix: str):
    """
    Load the weights for `prefix` by delegating to the `get_weights`
    method of the currently active weights loader.
    """
    active_loader = self.weights_loader
    return active_loader.get_weights(self, prefix)
|
|
|
|
|
2024-06-10 01:22:29 -06:00
|
|
|
def get_weights_col_packed_qkv(
    self,
    prefix: str,
    num_heads: int,
    num_key_value_heads: int,
):
    """
    Load packed QKV weights for `prefix`.

    The call is forwarded to `get_weights_col_packed` with the block
    specification `[num_heads, num_key_value_heads, num_key_value_heads]`.
    """
    qkv_blocks = [num_heads, num_key_value_heads, num_key_value_heads]
    return self.get_weights_col_packed(prefix, qkv_blocks)
|
2024-04-23 10:40:05 -06:00
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
def get_weights_col_packed_gate_up(self, prefix: str):
    """
    Load packed gate/up weights for `prefix`.

    The call is forwarded to `get_weights_col_packed` with a block count
    of 2.
    """
    num_blocks = 2
    return self.get_weights_col_packed(prefix, num_blocks)
|
2024-04-23 10:40:05 -06:00
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
    """
    Load column-packed weights for `prefix` through the active weights
    loader.

    The columns are split in equally sized blocks when `block_sizes` is
    an `int`, or in blocks proportional to the given sizes when it is a
    list. For instance `[2, 1, 1]` will divide an input with
    dimensionality `1024` in `[512, 256, 256]`. This is convenient for
    e.g. splitting QKV without knowing the storage details of quantized
    weights.
    """
    active_loader = self.weights_loader
    return active_loader.get_weights_col_packed(self, prefix, block_sizes)
|
2024-06-05 02:14:40 -06:00
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
def get_weights_col(self, prefix: str):
    """
    Load the weights for `prefix` by delegating to the `get_weights_col`
    method of the currently active weights loader.
    """
    active_loader = self.weights_loader
    return active_loader.get_weights_col(self, prefix)
|
2024-06-05 02:14:40 -06:00
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
def get_multi_weights_col(self, prefixes: List[str], dim: int):
    """
    Load the weights for several prefixes by delegating to the
    `get_multi_weights_col` method of the currently active weights
    loader; `prefixes` and `dim` are forwarded unchanged.
    """
    active_loader = self.weights_loader
    return active_loader.get_multi_weights_col(self, prefixes, dim)
|
2023-09-27 04:22:09 -06:00
|
|
|
|
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, the changes this time are mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
|
|
|
def get_tensor_shard(self, var, dim):
    """
    Return the current rank's shard of an already loaded tensor.

    `var` is split along `dim` into `world_size` contiguous blocks of
    `var.size()[dim] // world_size` elements each, and the block that
    belongs to the current rank is returned after being cast to
    `self.dtype` and moved to `self.device`.

    Any dimension of `var` can be used, including negative indices.

    Note that when the size of `dim` is not divisible by the world size,
    the trailing remainder is not part of any rank's shard.
    """
    world_size = self.process_group.size()
    rank = self.process_group.rank()
    block_size = var.size()[dim] // world_size
    # `narrow` selects `[start : start + block_size]` along `dim`, which
    # is the same as the explicit slicing for dim 0/1 but works for any
    # dimension.
    shard = var.narrow(dim, rank * block_size, block_size)
    shard = shard.to(dtype=self.dtype)
    return shard.to(device=self.device)
|
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
|
|
|
|
2024-07-09 12:04:03 -06:00
|
|
|
def get_weights_row(self, prefix: str):
    """
    Load the weights for `prefix` by delegating to the `get_weights_row`
    method of the currently active weights loader.
    """
    active_loader = self.weights_loader
    return active_loader.get_weights_row(self, prefix)
|
2024-06-10 01:22:29 -06:00
|
|
|
|
2024-07-15 23:58:25 -06:00
|
|
|
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
    """
    Context manager that makes `Weights` use a different loader for the
    duration of the context.

    The loader that was active on entry is put back when the context
    exits, including when the body raises.
    """
    previous_loader, self.weights_loader = self.weights_loader, weights_loader
    try:
        yield
    finally:
        # Always restore, so an exception cannot leave the temporary
        # loader installed.
        self.weights_loader = previous_loader
|
|
|
|
|
2024-09-17 10:08:58 -06:00
|
|
|
@property
def loader(self):
    """The weights loader currently set on this instance (read-only view)."""
    current_loader = self.weights_loader
    return current_loader
|
|
|
|
|
2024-06-10 01:22:29 -06:00
|
|
|
|
|
|
|
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list (or tuple) of block proportions (List[int]).

    In the latter case, total_size is divided in sum(blocks) equal parts
    and every block gets as many parts as its entry, so the ratios
    between the blocks are preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].

    An AssertionError is raised when the block count or the sum of the
    proportions is not positive, or when total_size cannot be divided
    evenly.
    """
    if isinstance(blocks, (list, tuple)):
        total_blocks = sum(blocks)
        # Guard the modulo/division below against an empty or all-zero
        # specification, which would otherwise fail with an unhelpful
        # ZeroDivisionError.
        assert (
            total_blocks > 0
        ), f"Block proportions must have a positive sum: {blocks}"
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]

    # A non-positive count would either divide by zero or silently
    # produce an empty list of blocks.
    assert blocks > 0, f"Number of blocks must be positive, got {blocks}"
    assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
    return [total_size // blocks] * blocks
|