hf_text-generation-inference/server/text_generation_server/utils/weights.py

350 lines
12 KiB
Python

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from safetensors import safe_open
class WeightsLoader(ABC):
    """
    Base class for higher-level weight loading strategies.

    On disk every weight lives in a Safetensors file, but what the raw
    tensors *mean* can vary (for instance packed or quantized weights).
    A loader is the component that interprets those raw tensors and shards
    them in a way that is compatible with the storage format.
    """

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Load a packed weight at `prefix`, column-split for tensor parallelism.

        Use this when several logical weights are stored in a single tensor,
        e.g. query/key/value weights or a gate/up projection.

        `block_sizes` describes the proportions of the packed tensors. An
        `int` splits the columns into that many equally sized blocks; a list
        splits them proportionally to its entries. For instance `[2, 1, 1]`
        divides an input with dimensionality `1024` into `[512, 256, 256]`.
        """

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Load the weight at `prefix`, column-split for tensor parallelism.

        This is the single-prefix case of `get_multi_weights_col`: the call
        is forwarded to `weights.get_multi_weights_col` with a one-element
        prefix list and dimension 0.
        """
        single_prefix = [prefix]
        return weights.get_multi_weights_col(single_prefix, 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Load the weights at each of `prefixes`, column-split them for tensor
        parallelism, and concatenate the results along dimension `dim`.
        """

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Load the weight at `prefix`, row-split for tensor parallelism.
        """
class DefaultWeightsLoader(WeightsLoader):
    """
    Loader for tensors that need no reinterpretation: the stored
    `<prefix>.weight` tensors are used as-is, apart from sharding and/or
    concatenation.
    """

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Return this rank's shard of the packed tensor `<prefix>.weight`,
        split along dimension 0 according to `block_sizes`.
        """
        tensor_name = f"{prefix}.weight"
        return weights.get_packed_sharded(tensor_name, dim=0, block_sizes=block_sizes)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Shard every `<prefix>.weight` along dimension 0 and concatenate the
        shards along `dim`.
        """
        shards = []
        for prefix in prefixes:
            shards.append(weights.get_sharded(f"{prefix}.weight", dim=0))
        return torch.cat(shards, dim=dim)

    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Return this rank's shard of `<prefix>.weight`, split along
        dimension 1.
        """
        tensor_name = f"{prefix}.weight"
        return weights.get_sharded(tensor_name, dim=1)
class Weights:
    """
    Accessor for tensors stored across one or more Safetensors files.

    On construction every file is scanned once to build a routing table from
    tensor name to the file that contains it. File handles used for reading
    are opened on first use and cached for the lifetime of the instance.

    Tensor names are resolved by `get_filename`, which also tries the
    optional `prefix` and any configured `aliases`. Sharding helpers use
    `process_group.size()` and `process_group.rank()` to select the slice
    that belongs to the current rank. Higher-level `get_weights_*` methods
    delegate to the configured `WeightsLoader`.
    """

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: WeightsLoader,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        """
        Args:
            filenames: Safetensors files to index.
            device: Device that loaded tensors are moved to.
            dtype: Dtype that loaded tensors are cast to (integer tensors
                used by quantizers are exempt, see the individual getters).
            process_group: Object exposing `size()` and `rank()`, used for
                sharding.
            weights_loader: Loader the `get_weights_*` methods delegate to.
            aliases: Optional mapping from a tensor name to alternative
                names to try when the name itself is not found.
            prefix: Optional prefix; `"<prefix>.<name>"` is tried when
                `<name>` cannot be resolved.

        Raises:
            RuntimeError: If the same tensor name occurs in more than one
                file.
        """
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # Cache of open Safetensors handles, keyed by filename.
        self._handles = {}

    def _get_handle(self, filename):
        """Return the cached Safetensors handle for `filename`, opening it on first use."""
        if filename not in self._handles:
            self._handles[filename] = safe_open(filename, framework="pytorch")
        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """
        Resolve `tensor_name` to `(filename, resolved_name)`.

        Candidates are tried in this order: the name itself, then its
        aliases, then `"<prefix>.<name>"` (when a prefix is set), then the
        aliases of the prefixed name.

        Raises:
            RuntimeError: If no candidate is present in any indexed file.
        """
        names = [tensor_name]
        if self.prefix is not None:
            names.append(f"{self.prefix}.{tensor_name}")
        for name in names:
            filename = self.routing.get(name)
            if filename is not None:
                return str(filename), name

            for alias in self.aliases.get(name, []):
                filename = self.routing.get(alias)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        """Return the lazy Safetensors slice object for `tensor_name`."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        return f.get_slice(tensor_name)

    def _has_tensor(self, tensor_name: str):
        """Return whether `tensor_name` can be resolved by `get_filename`."""
        try:
            self.get_filename(tensor_name)
        except RuntimeError:
            # `get_filename` signals a missing tensor with RuntimeError; any
            # other exception is a genuine error and must not be masked.
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the shape of `tensor_name` without loading its data."""
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        """
        Load the full tensor `tensor_name`.

        The tensor is cast to `self.dtype` unless it is int16/int32/int64,
        and moved to `self.device` when `to_device` is true.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        """
        Load this rank's shard of `tensor_name` along `dim` (0 or 1).

        The block size is the ceiling of `size / world_size`, so the size
        along `dim` does not have to be divisible by the world size; trailing
        ranks may then receive a smaller (possibly empty) shard.

        Raises:
            NotImplementedError: If `dim` is neither 0 nor 1.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        # Ceiling division: every rank gets at most `block_size` entries.
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        # NOTE(review): unlike `get_tensor` and `get_packed_sharded`, int64
        # is NOT exempt here and is cast to `self.dtype` - confirm whether
        # that difference is intentional before aligning the three checks.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        """
        Load this rank's shard of `tensor_name` along `dim`, requiring the
        size along `dim` to be evenly divisible by the world size.

        Raises:
            AssertionError: If the size is not divisible by the world size.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.

        Raises:
            AssertionError: If a block is not divisible by the world size.
            NotImplementedError: If `dim` is neither 0 nor 1.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        # Offset of the current packed block within the full tensor.
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """
        Load packed query/key/value weights, using block proportions
        `[num_heads, num_key_value_heads, num_key_value_heads]`.
        """
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Load packed gate/up weights as two equally sized blocks."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Delegate to the current loader's `get_weights_col`."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Delegate to the current loader's `get_multi_weights_col`."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """
        Return this rank's shard of the in-memory tensor `var` along `dim`
        (0 or 1), cast to `self.dtype` and moved to `self.device`.

        Raises:
            NotImplementedError: If `dim` is neither 0 nor 1.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        # NOTE(review): floor division - when the size along `dim` is not a
        # multiple of `world_size`, the trailing remainder belongs to no rank.
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Delegate to the current loader's `get_weights_row`."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: WeightsLoader):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context.
        """
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            # Restore the previous loader even if the body raised.
            self.weights_loader = old_loader
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the latter case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert (
total_size % total_blocks == 0
), f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks