hf_text-generation-inference/server/text_generation_server/layers/linear.py

125 lines
3.7 KiB
Python

import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
# On ROCm, optionally enable vLLM's custom skinny-GEMM kernels (used by
# FastLinearROCm). Controlled by the ROCM_USE_SKINNY_GEMM env var, which
# defaults to enabled; "true"/"1" (case-insensitive) turn it on.
if SYSTEM == "rocm":
    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
        "true",
        "1",
    )

    if ROCM_USE_SKINNY_GEMM:
        try:
            from vllm import _custom_C
        except Exception as e:
            # Chain the original failure so its traceback is reported as the
            # direct cause rather than as an incidental error during handling.
            raise ImportError(
                f"Could not load `vllm._custom_C`. Full error: {e}"
            ) from e
class FastLinear(torch.nn.Module):
    """Dense linear layer that simply forwards to ``F.linear``.

    Weight and bias are stored as frozen parameters (``requires_grad=False``);
    a missing bias is stored as ``None``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.bias = (
            torch.nn.Parameter(bias, requires_grad=False) if bias is not None else None
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from ``{prefix}.weight`` (and ``{prefix}.bias`` if ``bias``).

        ``config`` is accepted for interface compatibility but is not used here.
        ``weights`` must provide ``get_tensor(name)``.
        """
        weight_tensor = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight_tensor, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input @ weight.T + bias`` computed by ``F.linear``."""
        result = F.linear(input, self.weight, self.bias)
        return result
class FastLinearROCm(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(weight)
if bias is not None:
self.bias = torch.nn.Parameter(bias)
else:
self.bias = None
self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(weight, bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if (
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape
if inp.dim() == 3:
inp = inp.view(-1, inp_shape[-1])
batched = True
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
if batched:
out.view(*inp_shape[:-1], out.shape[-1])
if bias is not None:
out = out + bias
return out
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias):
    """Wrap ``weight``/``bias`` in the appropriate linear layer.

    Weights loaded through methods that are not quantization-aware are still
    bare ``torch.Tensor`` objects. These get a dense layer chosen by platform:
    ``FastLinearROCm`` on ROCm, ``FastLinear`` elsewhere. Any other weight
    object builds its own layer via ``weight.get_linear(bias)``. We may want
    to change the bare-tensor path in the future.
    """
    if not isinstance(weight, torch.Tensor):
        return weight.get_linear(bias)

    layer_cls = FastLinearROCm if SYSTEM == "rocm" else FastLinear
    return layer_cls(weight, bias)