From fd89d9dfaef408e4482ba5e230473de5bb203a99 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 May 2024 12:44:30 +0200 Subject: [PATCH] Refactor layers. (#1866) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- server/tests/utils/test_layers.py | 2 +- .../text_generation_server/layers/__init__.py | 14 + .../{utils => layers}/awq/conversion_utils.py | 0 .../{utils => layers}/awq/quantize/qmodule.py | 0 server/text_generation_server/layers/bnb.py | 106 ++ server/text_generation_server/layers/conv.py | 41 + server/text_generation_server/layers/eetq.py | 25 + server/text_generation_server/layers/fp8.py | 43 + .../layers/gptq/__init__.py | 39 + .../{utils => layers}/gptq/custom_autotune.py | 0 .../{utils => layers}/gptq/exllama.py | 0 .../{utils => layers}/gptq/exllamav2.py | 2 + .../layers/gptq/quant_linear.py | 356 +++++ .../{utils => layers}/gptq/quantize.py | 0 .../layers/layernorm.py | 185 +++ .../text_generation_server/layers/linear.py | 153 ++ .../text_generation_server/layers/medusa.py | 186 +++ .../text_generation_server/layers/rotary.py | 419 +++++ .../layers/speculative.py | 35 + .../layers/tensor_parallel.py | 188 +++ .../models/cache_manager.py | 4 +- .../models/custom_modeling/bloom_modeling.py | 2 +- .../models/custom_modeling/clip.py | 2 +- .../custom_modeling/flash_cohere_modeling.py | 18 +- .../custom_modeling/flash_dbrx_modeling.py | 19 +- .../custom_modeling/flash_gemma_modeling.py | 6 +- .../custom_modeling/flash_llama_modeling.py | 6 +- .../custom_modeling/flash_mistral_modeling.py | 6 +- .../custom_modeling/flash_mixtral_modeling.py | 14 +- .../custom_modeling/flash_neox_modeling.py | 10 +- .../custom_modeling/flash_phi_modeling.py | 8 +- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_rw_modeling.py | 10 +- .../flash_santacoder_modeling.py | 10 +- .../flash_starcoder2_modeling.py | 10 +- .../models/custom_modeling/idefics2.py | 2 +- .../custom_modeling/idefics_modeling.py | 16 +- .../custom_modeling/idefics_perceiver.py | 2 +- .../models/custom_modeling/idefics_vision.py | 2 +- .../models/custom_modeling/llava_next.py | 2 +- .../models/custom_modeling/mamba_modeling.py | 4 +- .../models/custom_modeling/mpt_modeling.py | 2 +- .../models/custom_modeling/neox_modeling.py | 2 +- .../models/custom_modeling/opt_modeling.py | 2 +- .../models/custom_modeling/phi_modeling.py | 2 +- .../models/custom_modeling/t5_modeling.py | 2 +- .../models/flash_causal_lm.py | 36 +- .../models/flash_llama.py | 4 +- .../models/flash_mistral.py | 4 +- .../models/flash_neox.py | 4 +- .../text_generation_server/models/flash_rw.py | 4 +- .../models/flash_santacoder.py | 4 +- server/text_generation_server/server.py | 2 +- .../utils/flash_attn.py | 197 ++- .../utils/gptq/quant_linear.py | 359 ----- .../utils/import_utils.py | 41 +- server/text_generation_server/utils/layers.py | 1374 ----------------- .../utils/paged_attention.py | 24 +- .../text_generation_server/utils/weights.py | 10 +- 59 files changed, 2085 insertions(+), 1941 deletions(-) create mode 100644 server/text_generation_server/layers/__init__.py rename server/text_generation_server/{utils => layers}/awq/conversion_utils.py (100%) rename server/text_generation_server/{utils => layers}/awq/quantize/qmodule.py (100%) create mode 100644 server/text_generation_server/layers/bnb.py create mode 100644 server/text_generation_server/layers/conv.py create mode 100644 server/text_generation_server/layers/eetq.py create mode 100644 server/text_generation_server/layers/fp8.py create mode 100644 server/text_generation_server/layers/gptq/__init__.py rename server/text_generation_server/{utils => layers}/gptq/custom_autotune.py (100%) rename server/text_generation_server/{utils => layers}/gptq/exllama.py (100%) rename server/text_generation_server/{utils => layers}/gptq/exllamav2.py (96%) create mode 100644 server/text_generation_server/layers/gptq/quant_linear.py rename server/text_generation_server/{utils => layers}/gptq/quantize.py (100%) create mode 100644 server/text_generation_server/layers/layernorm.py create mode 100644 server/text_generation_server/layers/linear.py create mode 100644 server/text_generation_server/layers/medusa.py create mode 100644 server/text_generation_server/layers/rotary.py create mode 100644 server/text_generation_server/layers/speculative.py create mode 100644 server/text_generation_server/layers/tensor_parallel.py delete mode 100644 server/text_generation_server/utils/gptq/quant_linear.py delete mode 100644 server/text_generation_server/utils/layers.py diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 93a0e982..9a8da0d6 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -1,5 +1,5 @@ import torch -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelEmbedding, ) diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py new file mode 100644 index 00000000..c3a6c921 --- /dev/null +++ b/server/text_generation_server/layers/__init__.py @@ -0,0 +1,14 @@ +from text_generation_server.layers.tensor_parallel import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, +) +from text_generation_server.layers.speculative import SpeculativeHead +from text_generation_server.layers.linear import ( + get_linear, + FastLinear, +) + +# Just to add the `load` methods. +from text_generation_server.layers.layernorm import load_layer_norm +from text_generation_server.layers.conv import load_conv2d diff --git a/server/text_generation_server/utils/awq/conversion_utils.py b/server/text_generation_server/layers/awq/conversion_utils.py similarity index 100% rename from server/text_generation_server/utils/awq/conversion_utils.py rename to server/text_generation_server/layers/awq/conversion_utils.py diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py similarity index 100% rename from server/text_generation_server/utils/awq/quantize/qmodule.py rename to server/text_generation_server/layers/awq/quantize/qmodule.py diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py new file mode 100644 index 00000000..d27a33a1 --- /dev/null +++ b/server/text_generation_server/layers/bnb.py @@ -0,0 +1,106 @@ +import torch +from loguru import logger +from functools import lru_cache +import bitsandbytes as bnb +from bitsandbytes.nn import Int8Params, Params4bit + + +@lru_cache(1) +def warn_deprecate_bnb(): + logger.warning( + "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" + ) + + +class Linear8bitLt(torch.nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.weight.cuda(weight.device) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + + +class Linear4bit(nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, + ) + self.compute_dtype = None + self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." + ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out diff --git a/server/text_generation_server/layers/conv.py b/server/text_generation_server/layers/conv.py new file mode 100644 index 00000000..7fb18ab3 --- /dev/null +++ b/server/text_generation_server/layers/conv.py @@ -0,0 +1,41 @@ +from accelerate import init_empty_weights +import torch + + +@classmethod +def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = torch.nn.Parameter(bias) + return conv2d + + +@classmethod +def load_conv2d_no_bias( + cls, prefix, weights, in_channels, out_channels, kernel_size, stride +): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = None + return conv2d + + +torch.nn.Conv2d.load = load_conv2d +torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py new file mode 100644 index 00000000..fd22b5c6 --- /dev/null +++ b/server/text_generation_server/layers/eetq.py @@ -0,0 +1,25 @@ +import torch +from EETQ import quant_weights, w8_a16_gemm + + +class EETQLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + device = weight.device + if weight.dtype != torch.float16: + weight = weight.to(dtype=torch.float16) + weight = torch.t(weight).contiguous().cpu() + weight, scale = quant_weights(weight, torch.int8, False) + + self.weight = weight.cuda(device) + self.scale = scale.cuda(device) + self.bias = bias.cuda(device) if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = w8_a16_gemm(input, self.weight, self.scale) + output = output + self.bias if self.bias is not None else output + return output diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py new file mode 100644 index 00000000..dd61d081 --- /dev/null +++ b/server/text_generation_server/layers/fp8.py @@ -0,0 +1,43 @@ +import torch + + +def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): + device = weight.device + # weight, scale = quant_weights(weight, torch.int8, False) + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(qdtype) + scale = scale.float().reciprocal() + return qweight, scale + + +class Fp8Linear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.dtype = weight.dtype + self.qweight, self.scale = fp8_quantize(weight) + + self.bias = bias if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + qinput, scale = fp8_quantize(input) + output, _ = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + return output diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py new file mode 100644 index 00000000..1c46f493 --- /dev/null +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -0,0 +1,39 @@ +import os +import torch +from text_generation_server.utils.import_utils import ( + SYSTEM, +) + +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass + +from text_generation_server.layers.gptq.quant_linear import QuantLinear diff --git a/server/text_generation_server/utils/gptq/custom_autotune.py b/server/text_generation_server/layers/gptq/custom_autotune.py similarity index 100% rename from server/text_generation_server/utils/gptq/custom_autotune.py rename to server/text_generation_server/layers/gptq/custom_autotune.py diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/layers/gptq/exllama.py similarity index 100% rename from server/text_generation_server/utils/gptq/exllama.py rename to server/text_generation_server/layers/gptq/exllama.py diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py similarity index 96% rename from server/text_generation_server/utils/gptq/exllamav2.py rename to server/text_generation_server/layers/gptq/exllamav2.py index 80836a95..321ced97 100644 --- a/server/text_generation_server/utils/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): none_tensor, temp_dq, ) + else: + RuntimeError("Cannot create handle") DEVICE = None diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py new file mode 100644 index 00000000..f60758b6 --- /dev/null +++ b/server/text_generation_server/layers/gptq/quant_linear.py @@ -0,0 +1,356 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_fwd + +import triton +import triton.language as tl +from . import custom_autotune + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load( + scales_ptrs + g_idx[:, None] * stride_scales + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + g_idx[:, None] * stride_zeros + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) & maxq # eventually avoid overflow + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty( + (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 + ) + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + return output + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py similarity index 100% rename from server/text_generation_server/utils/gptq/quantize.py rename to server/text_generation_server/layers/gptq/quantize.py diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py new file mode 100644 index 00000000..15d24e80 --- /dev/null +++ b/server/text_generation_server/layers/layernorm.py @@ -0,0 +1,185 @@ +import torch +from torch import nn +from accelerate import init_empty_weights +from text_generation_server.utils.import_utils import ( + SYSTEM, +) + + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = torch.nn.Parameter(bias) + return ln + + +@classmethod +def load_layer_norm_no_bias(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = None + return ln + + +torch.nn.LayerNorm.load = load_layer_norm +torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + from vllm import layernorm_ops + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super().forward(hidden_states), residual + +elif SYSTEM == "xpu": + import intel_extension_for_pytorch as ipex + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + res_out = hidden_states + out = ipex.llm.functional.add_layer_norm( + residual, hidden_states, self.weight, self.bias, self.eps, True + ) + if residual is not None: + res_out = residual + return out, res_out + + +class FastRMSNorm(nn.Module): + def __init__(self, weight: torch.Tensor, eps: float): + super().__init__() + + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + return cls(weight, eps) + + def forward(self, hidden_states, residual=None): + if SYSTEM == "xpu": + residual_out = hidden_states + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + True, + ) + if residual is not None: + residual_out = residual + return out, residual_out + elif hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + elif SYSTEM == "cuda": + # faster post attention rms norm + ( + normed_hidden_states, + res, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + elif SYSTEM == "rocm": + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py new file mode 100644 index 00000000..d137a500 --- /dev/null +++ b/server/text_generation_server/layers/linear.py @@ -0,0 +1,153 @@ +import torch +from torch.nn import functional as F +from text_generation_server.utils.import_utils import SYSTEM + + +class FastLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(weight) + if bias is not None: + self.bias = torch.nn.Parameter(bias) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +def get_linear(weight, bias, quantize): + if quantize is None: + linear = FastLinear(weight, bias) + elif quantize == "eetq": + try: + from text_generation_server.layers.eetq import EETQLinear + + linear = EETQLinear(weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) + elif quantize == "fp8": + from text_generation_server.layers.fp8 import Fp8Linear + + linear = Fp8Linear(weight, bias) + elif quantize == "bitsandbytes": + try: + from text_generation_server.layers.bnb import ( + warn_deprecate_bnb, + Linear8bitLt, + ) + except ImportError: + raise NotImplementedError( + f"Bitsandbytes is missing install it with `pip install bitsandbytes`." + ) + warn_deprecate_bnb() + linear = Linear8bitLt( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = nn.Parameter(bias) + elif quantize == "bitsandbytes-fp4": + try: + from text_generation_server.layers.bnb import Linear4bit + except ImportError: + raise NotImplementedError( + f"Bitsandbytes is missing install it with `pip install bitsandbytes`." + ) + linear = Linear4bit( + weight, + bias, + quant_type="fp4", + ) + elif quantize == "bitsandbytes-nf4": + try: + from text_generation_server.layers.bnb import Linear4bit + except ImportError: + raise NotImplementedError( + f"Bitsandbytes is missing install it with `pip install bitsandbytes`." + ) + linear = Linear4bit( + weight, + bias, + quant_type="nf4", + ) + elif quantize == "gptq": + try: + qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." + ) + + if use_exllama: + try: + from text_generation_server.layers.gptq import ( + ExllamaQuantLinear, + ) + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + linear = ExllamaQuantLinear( + qweight, qzeros, scales, g_idx, bias, bits, groupsize + ) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + linear = QuantLinear( + qweight, + qzeros, + scales, + g_idx, + bias, + bits, + groupsize, + ) + elif quantize == "awq": + try: + qweight, qzeros, scales, _, bits, groupsize, _ = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `awq` compatible, loader needs to be updated." + ) + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + linear = WQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + else: + raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") + return linear diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py new file mode 100644 index 00000000..4ac86978 --- /dev/null +++ b/server/text_generation_server/layers/medusa.py @@ -0,0 +1,186 @@ +import torch +from torch import nn +from typing import Tuple, Optional +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.layers.linear import FastLinear +from text_generation_server.layers.tensor_parallel import ( + TensorParallelHead, + TensorParallelColumnLinear, +) + + +class ResBlock(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) + self.act = torch.nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(torch.nn.Module): + def __init__(self, config, medusa_config, weights): + super().__init__() + self.heads = torch.nn.ModuleList( + [ + MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights) + for i in range(get_speculate()) + ] + ) + + def forward(self, x): + speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) + return speculative_logits + + +class MedusaHead(torch.nn.Module): + def __init__(self, config, medusa_config, prefix, weights): + super().__init__() + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(medusa_config["medusa_num_layers"]) + ] + ) + n = len(self.blocks) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.out(x) + return x + + +class MedusaHeadV1(nn.Module): + def __init__(self, lm_head, medusa): + super().__init__() + self.lm_head = lm_head + self.medusa = medusa + + @staticmethod + def load(config, prefix: str, weights): + from pathlib import Path + from safetensors import safe_open + import json + + use_medusa = config.use_medusa + + medusa_config = str(Path(use_medusa) / "config.json") + filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") + + with open(medusa_config, "r") as f: + medusa_config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + medusa = MedusaModel(config, medusa_config, weights) + lm_head = TensorParallelHead.load(config, prefix, weights) + return MedusaHeadV1(lm_head, medusa) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + # If we have too many tokens, we skip speculative logits + if input.shape[0] > 128: + return logits, None + + speculative_logits = self.medusa(input) + return logits, speculative_logits + + +class MedusaHeadV2(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + from pathlib import Path + from safetensors import safe_open + import json + + use_medusa = config.use_medusa + + medusa_config = str(Path(use_medusa) / "config.json") + filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") + + with open(medusa_config, "r") as f: + medusa_config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + self.n_medusa_heads = get_speculate() + + assert medusa_config["medusa_num_layers"] == 1 + self.linear = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)], + dim=0, + weights=weights, + bias=True, + ) + self.process_group = weights.process_group + self.world_size = self.process_group.size() + self.rank = self.process_group.rank() + + self.act = torch.nn.SiLU() + + self.lm_head = TensorParallelHead.load(config, prefix, weights) + + def forward(self, x): + # If we have too many tokens, we skip speculative logits + if x.shape[0] > 128: + logits = self.lm_head(x) + return logits, None + + size = x.shape[-1] + block_size = (size + self.world_size - 1) // self.world_size + start = self.rank * block_size + stop = (self.rank + 1) * block_size + + x_block = x[:, start:stop] + + # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1 + medusa_res = self.act(self.linear(x)).reshape( + *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1] + ) + + # Apply all residual medusa heads + output = x[:, start:stop].unsqueeze(-2) + medusa_res + + # Gather medusa heads + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + + # Stack x and medusa residual x + stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2) + + # Compute lm head on x + medusa residual x + logits = self.lm_head(stacked_x) + + # Finally, split logits from speculative logits + logits, speculative_logits = torch.split( + logits, [1, self.n_medusa_heads], dim=-2 + ) + # Squeeze added dimension + logits = logits.squeeze(-2) + + return logits, speculative_logits diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py new file mode 100644 index 00000000..503dd554 --- /dev/null +++ b/server/text_generation_server/layers/rotary.py @@ -0,0 +1,419 @@ +import os +import torch +from torch import nn + +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "cuda": + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb +elif SYSTEM == "rocm": + from vllm import pos_encoding_ops + + +def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + +def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = { + "type": os.environ["ROPE_SCALING"], + "factor": float(os.environ["ROPE_FACTOR"]), + } + return rope_scaling + return getattr(config, "rope_scaling", None) + + +class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + rotary_dim = cos.shape[-1] + q1 = query[..., :rotary_dim] + q2 = query[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., :rotary_dim] + k2 = key[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) + elif SYSTEM == "xpu": + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), True + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + scaling_factor = rope_scaling["factor"] + return DynamicPositionRotaryEmbedding( + dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + scaling_factor = rope_scaling["factor"] + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + ) + elif rope_scaling["type"] == "su": + short_factor = torch.tensor( + rope_scaling["short_factor"], dtype=torch.float32, device=device + ) + short_inv_freq = 1.0 / ( + short_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + long_factor = torch.tensor( + rope_scaling["long_factor"], dtype=torch.float32, device=device + ) + long_inv_freq = 1.0 / ( + long_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + + original_max_position_embeddings = ( + config.original_max_position_embeddings + ) + max_position_embeddings = config.max_position_embeddings + if max_position_embeddings <= original_max_position_embeddings: + scaling_factor = 1.0 + else: + scale = max_position_embeddings / original_max_position_embeddings + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(original_max_position_embeddings) + ) + + return SuRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + scaling_factor=scaling_factor, + original_max_position_embeddings=original_max_position_embeddings, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! + dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + if SYSTEM == "rocm": + # For RoCm, we always use float cos/sin to avoid a cast. + # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 + # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. + dtype = torch.float32 + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + + # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. + return cos.unsqueeze(1), sin.unsqueeze(1) + + +class SuRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + short_inv_freq, + long_inv_freq, + scaling_factor, + original_max_position_embeddings, + ): + super(PositionRotaryEmbedding, self).__init__() + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.dynamic_args = None + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + if seqlen > self.original_max_position_embeddings: + inv_freq = self.long_inv_freq + else: + inv_freq = self.short_inv_freq + t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq( + self.dim, newbase, self.inv_freq.device + ) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +# Inverse dim formula to find dim based on number of rotations +import math + + +def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def get_mscale(scale=1): + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings, + base, + device, + scaling_factor, + *, + extrapolation_factor, + attn_factor, + beta_fast, + beta_slow, + ): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + inv_freq_extrapolation = _create_inv_freq( + self.dim, self.base, self.inv_freq.device + ) + freqs = 1.0 / inv_freq_extrapolation + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + + self.inv_freq = inv_freq + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation + + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) diff --git a/server/text_generation_server/layers/speculative.py b/server/text_generation_server/layers/speculative.py new file mode 100644 index 00000000..663f8c2e --- /dev/null +++ b/server/text_generation_server/layers/speculative.py @@ -0,0 +1,35 @@ +import torch +from typing import Tuple, Optional +from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 +from text_generation_server.layers.tensor_parallel import TensorParallelHead + + +class SpeculativeHead(torch.nn.Module): + def __init__(self, lm_head, medusa): + super().__init__() + self.head = lm_head + self.medusa = medusa + + @staticmethod + def load(config, prefix: str, weights): + use_medusa = config.use_medusa + if use_medusa: + lm_head = None + try: + medusa = MedusaHeadV1.load(config, prefix, weights) + except: + medusa = MedusaHeadV2(config, prefix, weights) + else: + lm_head = TensorParallelHead.load(config, prefix, weights) + medusa = None + return SpeculativeHead(lm_head, medusa) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.medusa is not None: + return self.medusa(input) + + assert self.head is not None + logits = self.head(input) + return logits, None diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py new file mode 100644 index 00000000..34b9c51e --- /dev/null +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -0,0 +1,188 @@ +import torch +from torch.nn import functional as F +from typing import List +from text_generation_server.layers.linear import get_linear, FastLinear + + +class SuperLayer(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if weights.process_group.size() > 1: + try: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) + if config.quantize in ["gptq", "awq", "eetq"]: + quantize = None + else: + quantize = config.quantize + return TensorParallelHead( + get_linear(weight, bias=None, quantize=quantize), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + + world_size = self.process_group.size() + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + + output = super().forward(input) + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_gate_up(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_gate_up( + prefix, quantize=config.quantize + ) + if bias: + raise NotImplementedError("packed_gate_up only implemented without bias") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + return cls.load_multi(config, [prefix], weights, bias, dim=0) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1 and reduce: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class TensorParallelEmbedding(torch.nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = (num_embeddings + world_size - 1) // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = weight.shape[ + 0 + ] # Usually block_size, might be less in non even vocab_size. + self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + return out diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 4c65e2dd..c7705fe8 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -2,7 +2,7 @@ import math import torch from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM BLOCK_SIZE: int = 16 # Will be set in warmup @@ -25,7 +25,7 @@ class CacheManager: self.repeat_slots = repeat_slots element_size = torch.tensor([], dtype=dtype).element_size() - if IS_XPU_SYSTEM: + if SYSTEM == "xpu": x = 1 else: x = self.block_size // element_size diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index c8f02bca..0d8a1b59 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -32,7 +32,7 @@ from transformers.modeling_outputs import ( ) from transformers import BloomConfig, PreTrainedModel -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index c4917733..56618bf1 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -15,7 +15,7 @@ from transformers.modeling_outputs import ( ) from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 56d9a966..8c423eaf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -26,18 +26,22 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM -from text_generation_server.utils.layers import ( +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": import dropout_layer_norm else: dropout_layer_norm = None @@ -52,7 +56,7 @@ class CohereRotary(PositionRotaryEmbedding): sin: torch.Tensor, ): # Such controlflows may add some overhead. - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": import rotary_emb q1 = query[..., ::2] @@ -64,7 +68,7 @@ class CohereRotary(PositionRotaryEmbedding): k2 = key[..., 1::2] rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import pos_encoding_ops # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. @@ -90,7 +94,7 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d0978bef..9d652b67 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,21 +21,26 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM -if not IS_XPU_SYSTEM: +if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe + from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( FastLinear, - FastLayerNorm, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, ) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) from text_generation_server.utils.log import log_once @@ -216,7 +221,7 @@ def _load_gqa(config, prefix: str, weights): bits, groupsize, desc_act, quant_method = weights._get_gptq_params() - from text_generation_server.utils.layers import HAS_EXLLAMA + from text_generation_server.layers import HAS_EXLLAMA use_exllama = ( bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act @@ -236,7 +241,7 @@ def _load_gqa(config, prefix: str, weights): log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) - from text_generation_server.utils.awq.conversion_utils import ( + from text_generation_server.layers.awq.conveersion_utils import ( fast_awq_to_gptq, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index bd7596db..43b90bdd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -27,13 +27,15 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( FastRMSNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6fa85d4e..a7969494 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,13 +27,15 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( FastRMSNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index c2445cda..3e13c26d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,13 +27,15 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( FastRMSNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3f6c8e03..be2d6c45 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,9 +24,9 @@ import torch.distributed import numpy as np from torch import nn -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM -if not IS_XPU_SYSTEM: +if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig @@ -34,16 +34,20 @@ from typing import Optional, List, Tuple from loguru import logger from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( FastLinear, - FastRMSNorm, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, ) +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) class MixtralConfig(PretrainedConfig): diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ee062d3d..d45cab2e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,15 +29,19 @@ from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - FastLayerNorm, - PositionRotaryEmbedding, get_linear, ) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) def load_row(config, prefix: str, weights, bias: bool): diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index cfe447a7..f2efb538 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -7,15 +7,19 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) class PhiConfig(PretrainedConfig): diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 94023b33..3a6d2db5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -6,13 +6,15 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( FastRMSNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index a9127d1f..52ea3ae1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -8,15 +8,19 @@ from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - FastLayerNorm, - PositionRotaryEmbedding, get_linear, ) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) def load_row(config, prefix: str, weights, bias: bool): diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index bbb603a7..d2f6d9af 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -6,14 +6,16 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, SpeculativeHead, TensorParallelEmbedding, - FastLayerNorm, get_linear, ) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) def load_multi_mqa( @@ -80,13 +82,13 @@ def _load_multi_mqa_gptq( g_idx = g_idx.to(device=weights.device) elif quant_method == "awq": g_idx = None - from text_generation_server.utils.awq.conversion_utils import ( + from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, ) qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - from text_generation_server.utils.layers import HAS_EXLLAMA + from text_generation_server.layers.gptq import HAS_EXLLAMA use_exllama = HAS_EXLLAMA weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index ed77af78..3e2ce4f9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -27,15 +27,19 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, - FastRMSNorm, +) +from text_generation_server.layers.layernorm import ( FastLayerNorm, + FastRMSNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, ) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index cb2ee7db..935f049b 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -29,7 +29,7 @@ from text_generation_server.models.custom_modeling.vlm import ( ) from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index ee4cdb08..ec3f900b 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -47,20 +47,22 @@ from text_generation_server.models.custom_modeling.idefics_vision import ( from text_generation_server.models.custom_modeling.idefics_perceiver import ( IdeficsPerceiverResampler, ) -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, SpeculativeHead, - PositionRotaryEmbedding, FastLinear, ) -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.utils.import_utils import SYSTEM -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": import dropout_layer_norm -elif IS_ROCM_SYSTEM: +elif SYSTEM == "rocm": from vllm import layernorm_ops +else: + raise RuntimeError(f"Unsupported system {SYSTEM}") @dataclass @@ -373,7 +375,7 @@ class IdeficsRMSNorm(nn.Module): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states - elif IS_CUDA_SYSTEM: + elif SYSTEM == "cuda": # faster post attention rms norm unwrap = False if len(hidden_states.shape) > 2: @@ -405,7 +407,7 @@ class IdeficsRMSNorm(nn.Module): normed_hidden_states = normed_hidden_states.view(*shape) return normed_hidden_states - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py index 477d4d70..af44490b 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py +++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -41,7 +41,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, ) diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py index c521dd0a..30c5997f 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -28,7 +28,7 @@ from transformers.utils import ( ModelOutput, logging, ) -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, TensorParallelEmbedding, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 0d93791f..a049f756 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -27,7 +27,7 @@ from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, ) -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, ) diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index c58a617f..293051c2 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -8,12 +8,12 @@ from typing import Optional, Tuple, Any from transformers.configuration_utils import PretrainedConfig import torch.nn.functional as F -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( SpeculativeHead, TensorParallelEmbedding, - FastRMSNorm, FastLinear, ) +from text_generation_server.layers.layernorm import FastRMSNorm from einops import rearrange from causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 9b0f8b92..f7981bf5 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -17,7 +17,7 @@ from transformers.modeling_outputs import ( ) from einops import rearrange from packaging import version -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index f060ec0e..fcad32fa 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -40,7 +40,7 @@ from transformers.modeling_outputs import ( from transformers.modeling_utils import PreTrainedModel from transformers import GPTNeoXConfig from loguru import logger -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 7a5cf917..83d62dea 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -27,7 +27,7 @@ from transformers.modeling_outputs import ( ) from transformers.modeling_utils import PreTrainedModel from transformers import OPTConfig -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( FastLinear, TensorParallelColumnLinear, TensorParallelEmbedding, diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index 1571f9fd..04b470eb 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -9,7 +9,7 @@ from typing import Optional, List, Tuple, Any from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 2773fb15..0b899fba 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -38,7 +38,7 @@ from transformers.utils import ( is_torch_fx_proxy, ) from transformers import T5Config -from text_generation_server.utils.layers import ( +from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a6d0204f..f567bea9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict - from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.speculate import get_speculate @@ -32,13 +31,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION -tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, + empty_cache, + synchronize, + get_free_memory, ) +tracer = trace.get_tracer(__name__) + @dataclass class FlashCausalLMBatch(Batch): @@ -757,10 +757,8 @@ class FlashCausalLM(Model): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - torch.cuda.empty_cache() - elif IS_XPU_SYSTEM: - torch.xpu.empty_cache() + empty_cache() + try: cache_manager = set_cache_manager( batch.blocks, @@ -780,10 +778,7 @@ class FlashCausalLM(Model): f"You need to decrease `--max-batch-prefill-tokens`" ) from e - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - torch.cuda.synchronize(self.device) - elif IS_XPU_SYSTEM: - torch.xpu.synchronize(self.device) + synchronize(self.device) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -791,20 +786,7 @@ class FlashCausalLM(Model): cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - total_free_memory, _ = torch.cuda.mem_get_info(self.device) - total_gpu_memory = torch.cuda.get_device_properties( - self.device - ).total_memory - - free_memory = max( - 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory - ) - elif IS_XPU_SYSTEM: - total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory - free_memory = int(total_gpu_memory * 0.5) - else: - raise NotImplementedError("FlashModel is only available on GPU") + free_memory = get_free_memory(self.device, MEMORY_FRACTION) num_blocks = ( # Leave 5% for some wiggle room diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 609a188d..8ea70713 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM class FlashLlama(FlashCausalLM): @@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 85e93543..48304ad8 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index f82e27db..1119bdae 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index ccf38a0c..33298e1a 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index e66f1bf8..66698a3a 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 158966e3..9d0571a6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): # When using GPTQ, Exllama kernels need some global kernels # For which we have the finale shapes only after the model has loaded # This will allocate those buffers. - from text_generation_server.utils.layers import ( + from text_generation_server.layers.gptq import ( create_exllama_buffers, set_device, ) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 583a8f91..0830656d 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -2,13 +2,8 @@ import os import torch from loguru import logger -import math -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, -) +from text_generation_server.utils.import_utils import SYSTEM if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") @@ -16,83 +11,22 @@ HAS_FLASH_ATTN = True HAS_FLASH_ATTN_V2_CUDA = False HAS_FLASH_ATTN_V2_ROCM = False -if IS_XPU_SYSTEM: +if SYSTEM == "xpu": import intel_extension_for_pytorch as ipex -if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - - HAS_FLASH_ATTN = False - HAS_FLASH_ATTN_V2_CUDA = False - HAS_FLASH_ATTN_V2_ROCM = False - try: - try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = "" - if IS_CUDA_SYSTEM: - architecture_suffix = "-cuda" - elif IS_ROCM_SYSTEM: - architecture_suffix = "-rocm" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM - HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM - except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif IS_ROCM_SYSTEM: - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True - - -def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, -): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - if IS_XPU_SYSTEM: if window_size_left != -1: raise ValueError( f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." @@ -114,7 +48,77 @@ def attention( None, ) - if HAS_FLASH_ATTN_V2_CUDA: + +if SYSTEM in {"cuda", "rocm"}: + if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + + major, minor = torch.cuda.get_device_capability() + is_sm75 = major == 7 and minor == 5 + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + + HAS_FLASH_ATTN = False + HAS_FLASH_ATTN_V2_CUDA = False + HAS_FLASH_ATTN_V2_ROCM = False + try: + try: + import flash_attn_2_cuda + except ImportError: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda" + HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm" + except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + elif SYSTEM == "rocm": + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +if HAS_FLASH_ATTN_V2_CUDA: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, k, @@ -136,7 +140,21 @@ def attention( False, None, ) - elif HAS_FLASH_ATTN_V2_ROCM: + +elif HAS_FLASH_ATTN_V2_ROCM: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") if window_size_left != -1: raise ValueError( f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." @@ -159,7 +177,19 @@ def attention( False, None, ) - elif HAS_FLASH_ATTN: + +elif HAS_FLASH_ATTN: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): if window_size_left != -1: raise NotImplementedError( "window_size_left is only available with flash attn v2" @@ -209,4 +239,5 @@ def attention( None, ) +else: raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py deleted file mode 100644 index a832f755..00000000 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd - -try: - import triton - import triton.language as tl - from . import custom_autotune - - # code based https://github.com/fpgaminer/GPTQ-triton - @custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, - ) - @triton.jit - def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -except: - print("triton not installed.") - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index db205f4d..f54987eb 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -10,6 +10,41 @@ def is_xpu_available(): return hasattr(torch, "xpu") and torch.xpu.is_available() -IS_ROCM_SYSTEM = torch.version.hip is not None -IS_CUDA_SYSTEM = torch.version.cuda is not None -IS_XPU_SYSTEM = is_xpu_available() +def get_cuda_free_memory(device, memory_fraction): + total_free_memory, _ = torch.cuda.mem_get_info(device) + total_gpu_memory = torch.cuda.get_device_properties(device).total_memory + free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) + return free_memory + + +def get_xpu_free_memory(device): + total_gpu_memory = torch.xpu.get_device_properties(device).total_memory + free_memory = int(total_gpu_memory * 0.5) + return free_memory + + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif is_xpu_available(): + SYSTEM = "xpu" + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory +else: + SYSTEM = "cpu" + + def noop(*args, **kwargs): + pass + + empty_cache = noop + synchronize = noop + get_free_memory = noop diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py deleted file mode 100644 index 7d339fe5..00000000 --- a/server/text_generation_server/utils/layers.py +++ /dev/null @@ -1,1374 +0,0 @@ -import os -import torch -import torch.distributed - -from torch import nn -from torch.nn import functional as F -from typing import List, Tuple, Optional -from loguru import logger -from functools import lru_cache - -from text_generation_server.utils.speculate import get_speculate - -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params, Params4bit -except ImportError: - HAS_BITS_AND_BYTES = False - -from accelerate import init_empty_weights - -from text_generation_server.utils.gptq.quant_linear import QuantLinear -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, -) - -if IS_XPU_SYSTEM: - import intel_extension_for_pytorch as ipex - -HAS_AWQ = True -try: - from text_generation_server.utils.awq.quantize.qmodule import WQLinear -except ImportError: - HAS_AWQ = False - -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" - -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.utils.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.utils.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass - -HAS_EETQ = False -try: - from EETQ import quant_weights, w8_a16_gemm - - HAS_EETQ = True -except ImportError: - pass - - -# Monkey patching -@classmethod -def load_layer_norm(cls, prefix, weights, eps): - weight = weights.get_tensor(f"{prefix}.weight") - bias = weights.get_tensor(f"{prefix}.bias") - with init_empty_weights(): - ln = cls(weight.shape, eps=eps) - - ln.weight = nn.Parameter(weight) - ln.bias = nn.Parameter(bias) - return ln - - -@classmethod -def load_layer_norm_no_bias(cls, prefix, weights, eps): - weight = weights.get_tensor(f"{prefix}.weight") - with init_empty_weights(): - ln = cls(weight.shape, eps=eps) - - ln.weight = nn.Parameter(weight) - ln.bias = None - return ln - - -@classmethod -def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): - weight = weights.get_tensor(f"{prefix}.weight") - bias = weights.get_tensor(f"{prefix}.bias") - with init_empty_weights(): - conv2d = cls( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - ) - - conv2d.weight = nn.Parameter(weight) - conv2d.bias = nn.Parameter(bias) - return conv2d - - -@classmethod -def load_conv2d_no_bias( - cls, prefix, weights, in_channels, out_channels, kernel_size, stride -): - weight = weights.get_tensor(f"{prefix}.weight") - with init_empty_weights(): - conv2d = cls( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - ) - - conv2d.weight = nn.Parameter(weight) - conv2d.bias = None - return conv2d - - -torch.nn.Conv2d.load = load_conv2d -torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias -torch.nn.LayerNorm.load = load_layer_norm -torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias - - -class FastLinear(nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = nn.Parameter(weight) - if bias is not None: - self.bias = nn.Parameter(bias) - else: - self.bias = None - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls(weight, bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, self.weight, self.bias) - - -class EETQLinear(nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - device = weight.device - if weight.dtype != torch.float16: - weight = weight.to(dtype=torch.float16) - weight = torch.t(weight).contiguous().cpu() - weight, scale = quant_weights(weight, torch.int8, False) - - self.weight = weight.cuda(device) - self.scale = scale.cuda(device) - self.bias = bias.cuda(device) if bias is not None else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = w8_a16_gemm(input, self.weight, self.scale) - output = output + self.bias if self.bias is not None else output - return output - - -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device - # weight, scale = quant_weights(weight, torch.int8, False) - finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) - scale = scale.float().reciprocal() - return qweight, scale - - -class Fp8Linear(nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) - - self.bias = bias if bias is not None else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - qinput, scale = fp8_quantize(input) - output, _ = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, - bias=self.bias, - ) - return output - - -class Linear8bitLt(nn.Module): - def __init__( - self, - weight, - bias, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - ): - super().__init__() - assert ( - not memory_efficient_backward - ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" - self.state = bnb.MatmulLtState() - self.index = index - - # Necessary for stacked layers - self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True - - self.weight = Int8Params( - weight.data, - has_fp16_weights=has_fp16_weights, - requires_grad=has_fp16_weights, - ) - self.weight.cuda(weight.device) - self.bias = bias - - def init_8bit_state(self): - self.state.CB = self.weight.CB - self.state.SCB = self.weight.SCB - self.weight.CB = None - self.weight.SCB = None - - def forward(self, x: torch.Tensor): - self.state.is_training = self.training - if self.weight.CB is not None: - self.init_8bit_state() - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - - if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB - return out - - -class Linear4bit(nn.Module): - def __init__(self, weight, bias, quant_type): - super().__init__() - self.weight = Params4bit( - weight.data, - requires_grad=False, - compress_statistics=True, - quant_type=quant_type, - ) - self.compute_dtype = None - self.weight.cuda(weight.device) - self.bias = bias - - def forward(self, x: torch.Tensor): - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - if getattr(self.weight, "quant_state", None) is None: - print( - "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." - ) - inp_dtype = x.dtype - if self.compute_dtype is not None: - x = x.to(self.compute_dtype) - - bias = None if self.bias is None else self.bias.to(self.compute_dtype) - out = bnb.matmul_4bit( - x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state - ) - - out = out.to(inp_dtype) - - return out - - -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) - - -def get_linear(weight, bias, quantize): - if quantize is None: - linear = FastLinear(weight, bias) - elif quantize == "eetq": - if HAS_EETQ: - linear = EETQLinear(weight, bias) - else: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - elif quantize == "fp8": - linear = Fp8Linear(weight, bias) - elif quantize == "bitsandbytes": - warn_deprecate_bnb() - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-fp4": - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "bitsandbytes-nf4": - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "gptq": - try: - qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight - except Exception: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) - - if use_exllama: - linear = ExllamaQuantLinear( - qweight, qzeros, scales, g_idx, bias, bits, groupsize - ) - else: - linear = QuantLinear( - qweight, - qzeros, - scales, - g_idx, - bias, - bits, - groupsize, - ) - elif quantize == "awq": - try: - qweight, qzeros, scales, _, bits, groupsize, _ = weight - except Exception: - raise NotImplementedError( - f"The passed weight is not `awq` compatible, loader needs to be updated." - ) - if IS_ROCM_SYSTEM: - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) - if not HAS_AWQ: - raise NotImplementedError( - "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" - ) - linear = WQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, - bias=bias is not None, - ) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear - - -class SuperLayer(nn.Module): - def __init__(self, linear): - super().__init__() - self.linear = linear - - def forward(self, x): - return self.linear.forward(x) - - -class ResBlock(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) - self.act = torch.nn.SiLU() - - def forward(self, x): - return x + self.act(self.linear(x)) - - -class MedusaModel(torch.nn.Module): - def __init__(self, config, medusa_config, weights): - super().__init__() - self.heads = torch.nn.ModuleList( - [ - MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights) - for i in range(get_speculate()) - ] - ) - - def forward(self, x): - speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) - return speculative_logits - - -class MedusaHead(torch.nn.Module): - def __init__(self, config, medusa_config, prefix, weights): - super().__init__() - self.blocks = torch.nn.ModuleList( - [ - ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) - for i in range(medusa_config["medusa_num_layers"]) - ] - ) - n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) - - def forward(self, x): - for block in self.blocks: - x = block(x) - x = self.out(x) - return x - - -class MedusaHeadV1(nn.Module): - def __init__(self, lm_head, medusa): - super().__init__() - self.lm_head = lm_head - self.medusa = medusa - - @staticmethod - def load(config, prefix: str, weights): - from pathlib import Path - from safetensors import safe_open - import json - - use_medusa = config.use_medusa - - medusa_config = str(Path(use_medusa) / "config.json") - filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") - - with open(medusa_config, "r") as f: - medusa_config = json.load(f) - routing = weights.routing - with safe_open(filename, framework="pytorch") as f: - for k in f.keys(): - if k in routing and routing[k] != filename: - raise RuntimeError( - f"Key {k} was found in multiple files: {filename} and {routing[k]}" - ) - routing[k] = filename - - medusa = MedusaModel(config, medusa_config, weights) - lm_head = TensorParallelHead.load(config, prefix, weights) - return MedusaHeadV1(lm_head, medusa) - - def forward( - self, input: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - logits = self.lm_head(input) - # If we have too many tokens, we skip speculative logits - if input.shape[0] > 128: - return logits, None - - speculative_logits = self.medusa(input) - return logits, speculative_logits - - -class MedusaHeadV2(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - from pathlib import Path - from safetensors import safe_open - import json - - use_medusa = config.use_medusa - - medusa_config = str(Path(use_medusa) / "config.json") - filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") - - with open(medusa_config, "r") as f: - medusa_config = json.load(f) - routing = weights.routing - with safe_open(filename, framework="pytorch") as f: - for k in f.keys(): - if k in routing and routing[k] != filename: - raise RuntimeError( - f"Key {k} was found in multiple files: {filename} and {routing[k]}" - ) - routing[k] = filename - - self.n_medusa_heads = get_speculate() - - assert medusa_config["medusa_num_layers"] == 1 - self.linear = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)], - dim=0, - weights=weights, - bias=True, - ) - self.process_group = weights.process_group - self.world_size = self.process_group.size() - self.rank = self.process_group.rank() - - self.act = torch.nn.SiLU() - - self.lm_head = TensorParallelHead.load(config, prefix, weights) - - def forward(self, x): - # If we have too many tokens, we skip speculative logits - if x.shape[0] > 128: - logits = self.lm_head(x) - return logits, None - - size = x.shape[-1] - block_size = (size + self.world_size - 1) // self.world_size - start = self.rank * block_size - stop = (self.rank + 1) * block_size - - x_block = x[:, start:stop] - - # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1 - medusa_res = self.act(self.linear(x)).reshape( - *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1] - ) - - # Apply all residual medusa heads - output = x[:, start:stop].unsqueeze(-2) + medusa_res - - # Gather medusa heads - world_output = [ - torch.empty_like(output) for _ in range(self.process_group.size()) - ] - torch.distributed.all_gather(world_output, output, group=self.process_group) - world_output = torch.cat(world_output, dim=-1) - - # Stack x and medusa residual x - stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2) - - # Compute lm head on x + medusa residual x - logits = self.lm_head(stacked_x) - - # Finally, split logits from speculative logits - logits, speculative_logits = torch.split( - logits, [1, self.n_medusa_heads], dim=-2 - ) - # Squeeze added dimension - logits = logits.squeeze(-2) - - return logits, speculative_logits - - -class SpeculativeHead(nn.Module): - def __init__(self, lm_head, medusa): - super().__init__() - self.head = lm_head - self.medusa = medusa - - @staticmethod - def load(config, prefix: str, weights): - use_medusa = config.use_medusa - if use_medusa: - lm_head = None - try: - medusa = MedusaHeadV1.load(config, prefix, weights) - except: - medusa = MedusaHeadV2(config, prefix, weights) - else: - lm_head = TensorParallelHead.load(config, prefix, weights) - medusa = None - return SpeculativeHead(lm_head, medusa) - - def forward( - self, input: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.medusa is not None: - return self.medusa(input) - - assert self.head is not None - logits = self.head(input) - return logits, None - - -class TensorParallelHead(SuperLayer): - def __init__(self, linear, process_group, should_gather: bool): - super().__init__(linear) - self.process_group = process_group - self.should_gather = should_gather - - @staticmethod - def load(config, prefix: str, weights): - if weights.process_group.size() > 1: - try: - weight = weights.get_sharded(f"{prefix}.weight", dim=0) - should_gather = True - except AssertionError: - # If the vocab size is not divisible by number of shards - # just load the entire thing. - weight = weights.get_tensor(f"{prefix}.weight") - should_gather = False - else: - weight = weights.get_tensor(f"{prefix}.weight") - should_gather = False - - # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq"]: - quantize = None - else: - quantize = config.quantize - return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), - process_group=weights.process_group, - should_gather=should_gather, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if not self.should_gather: - return super().forward(input) - - world_size = self.process_group.size() - if len(input.shape) == 2 and isinstance(self.linear, FastLinear): - out_dim = self.linear.weight.shape[0] - - if input.shape[0] == 1: - world_out = input.new_empty(1, out_dim * world_size) - local_out = input.new_empty(1, out_dim) - gather_input = local_out - else: - world_out = input.new_empty(out_dim * world_size, input.shape[0]) - gather_input = input.new_empty(out_dim, input.shape[0]) - local_out = gather_input.T - - torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) - - if input.shape[0] == 1: - return world_out - return world_out.T - - output = super().forward(input) - world_output = [ - torch.empty_like(output) for _ in range(self.process_group.size()) - ] - torch.distributed.all_gather(world_output, output, group=self.process_group) - world_output = torch.cat(world_output, dim=-1) - return world_output - - -class TensorParallelColumnLinear(SuperLayer): - @classmethod - def load_gate_up(cls, config, prefix: str, weights, bias: bool): - """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) - if bias: - raise NotImplementedError("packed_gate_up only implemented without bias") - else: - bias = None - linear = get_linear(weight, bias, config.quantize) - return cls(linear) - - @classmethod - def load_qkv(cls, config, prefix: str, weights, bias: bool): - """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) - if bias: - raise NotImplementedError("packed_qkv only implemented for baichuan") - else: - bias = None - linear = get_linear(weight, bias, config.quantize) - return cls(linear) - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - return cls.load_multi(config, [prefix], weights, bias, dim=0) - - @classmethod - def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) - - if bias: - b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] - bias = torch.cat(b, dim=dim) - else: - bias = None - linear = get_linear(weight, bias, config.quantize) - return cls(linear) - - -class TensorParallelRowLinear(SuperLayer): - def __init__(self, linear, process_group): - super().__init__(linear) - self.process_group = process_group - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) - - if bias and weights.process_group.rank() == 0: - # Rank is only on the first rank process - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls( - get_linear(weight, bias, config.quantize), - process_group=weights.process_group, - ) - - def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: - out = super().forward(input) - if self.process_group.size() > 1 and reduce: - torch.distributed.all_reduce(out, group=self.process_group) - return out - - -class TensorParallelEmbedding(nn.Module): - def __init__(self, prefix: str, weights, reduce=True): - super().__init__() - weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) - num_embeddings = weights.get_shape(f"{prefix}.weight")[0] - - process_group = weights.process_group - - world_size = process_group.size() - rank = process_group.rank() - - block_size = (num_embeddings + world_size - 1) // world_size - self.min_id = rank * block_size - self.max_id = min(num_embeddings, (rank + 1) * block_size) - self.null_idx = weight.shape[ - 0 - ] # Usually block_size, might be less in non even vocab_size. - self.process_group = weights.process_group - self.reduce = reduce - - """Additional 0 entry used for masking""" - self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1))) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # default all out of bounds values to `self.null_idx` that will then be mapped to 0 - # translate for [0, self.max_id - self.min_id[ - input = torch.where( - (self.min_id > input) | (input >= self.max_id), - self.null_idx, - input - self.min_id, - ) - out = torch.nn.functional.embedding(input, self.weight) - if self.reduce and self.process_group.size() > 1: - torch.distributed.all_reduce(out, group=self.process_group) - return out - - -try: - if IS_CUDA_SYSTEM: - import dropout_layer_norm - elif IS_ROCM_SYSTEM: - from vllm import layernorm_ops - else: - dropout_layer_norm = None - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if IS_XPU_SYSTEM: - res_out = hidden_states - out = ipex.llm.functional.add_layer_norm( - residual, hidden_states, self.weight, self.bias, self.eps, True - ) - if residual is not None: - res_out = residual - return out, res_out - elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: - if residual is not None: - hidden_states += residual - residual = hidden_states - - return super(FastLayerNorm, self).forward(hidden_states), residual - else: - ( - normed_hidden_states, - residual, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - if residual is None: - residual = hidden_states - - return normed_hidden_states, residual - - class FastRMSNorm(nn.Module): - def __init__(self, weight: torch.Tensor, eps: float): - super().__init__() - - self.weight = nn.Parameter(weight) - self.variance_epsilon = eps - - @classmethod - def load(cls, prefix, weights, eps=1e-6): - weight = weights.get_tensor(f"{prefix}.weight") - return cls(weight, eps) - - def forward(self, hidden_states, residual=None): - if IS_XPU_SYSTEM: - residual_out = hidden_states - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - True, - ) - if residual is not None: - residual_out = residual - return out, residual_out - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif IS_CUDA_SYSTEM: - # faster post attention rms norm - ( - normed_hidden_states, - res, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif IS_ROCM_SYSTEM: - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - layernorm_ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) - -except ImportError: - pass - -try: - if IS_CUDA_SYSTEM: - from flash_attn.layers.rotary import RotaryEmbedding - import rotary_emb - elif IS_ROCM_SYSTEM: - from vllm import pos_encoding_ops - - def _create_inv_freq(dim, base, device): - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - return inv_freq - - def _get_rope_config(config): - if os.getenv("ROPE_SCALING", None) is not None: - rope_scaling = { - "type": os.environ["ROPE_SCALING"], - "factor": float(os.environ["ROPE_FACTOR"]), - } - return rope_scaling - return getattr(config, "rope_scaling", None) - - class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq, scaling_factor): - super().__init__() - self.inv_freq = inv_freq - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - self.scaling_factor = scaling_factor - self.dynamic_args = None - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ): - # Such controlflows may add some overhead. - if IS_CUDA_SYSTEM: - rotary_dim = cos.shape[-1] - q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif IS_ROCM_SYSTEM: - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif IS_XPU_SYSTEM: - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), True - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) - - @classmethod - def static(cls, config, dim, base, device): - inv_freq = _create_inv_freq(dim, base, device) - scaling_factor = None - rope_scaling = _get_rope_config(config) - if rope_scaling is not None: - if rope_scaling["type"] == "linear": - pass - elif rope_scaling["type"] == "dynamic": - scaling_factor = rope_scaling["factor"] - return DynamicPositionRotaryEmbedding( - dim=dim, - max_position_embeddings=config.max_position_embeddings, - base=base, - device=inv_freq.device, - scaling_factor=scaling_factor, - ) - elif rope_scaling["type"] == "yarn": - scaling_factor = rope_scaling["factor"] - return YarnPositionRotaryEmbedding( - dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling[ - "original_max_position_embeddings" - ], - base=10000.0, - device=inv_freq.device, - scaling_factor=scaling_factor, - extrapolation_factor=1, - attn_factor=1, - beta_fast=32, - beta_slow=1, - ) - elif rope_scaling["type"] == "su": - short_factor = torch.tensor( - rope_scaling["short_factor"], dtype=torch.float32, device=device - ) - short_inv_freq = 1.0 / ( - short_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) - ) - long_factor = torch.tensor( - rope_scaling["long_factor"], dtype=torch.float32, device=device - ) - long_inv_freq = 1.0 / ( - long_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) - ) - - original_max_position_embeddings = ( - config.original_max_position_embeddings - ) - max_position_embeddings = config.max_position_embeddings - if max_position_embeddings <= original_max_position_embeddings: - scaling_factor = 1.0 - else: - scale = ( - max_position_embeddings / original_max_position_embeddings - ) - scaling_factor = math.sqrt( - 1 - + math.log(scale) - / math.log(original_max_position_embeddings) - ) - - return SuRotaryEmbedding( - short_inv_freq=short_inv_freq, - long_inv_freq=long_inv_freq, - scaling_factor=scaling_factor, - original_max_position_embeddings=original_max_position_embeddings, - ) - else: - raise NotImplementedError( - f"rope scaling type {rope_scaling['type']} is not implemented or invalid" - ) - return cls(inv_freq, scaling_factor) - - @classmethod - def load(cls, config, prefix, weights): - # XXX: Always load this in float32 ! - dtype = weights.dtype - weights.dtype = torch.float32 - inv_freq = weights.get_tensor(f"{prefix}.inv_freq") - weights.dtype = dtype - - scaling_factor = None - rope_scaling = _get_rope_config(config) - if rope_scaling is not None: - scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "linear": - pass - elif rope_scaling["type"] == "dynamic": - return DynamicPositionRotaryEmbedding( - dim=2 * inv_freq.shape[0], - max_position_embeddings=config.max_position_embeddings, - base=10000.0, - device=inv_freq.device, - scaling_factor=scaling_factor, - ) - elif rope_scaling["type"] == "yarn": - return YarnPositionRotaryEmbedding( - dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling[ - "original_max_position_embeddings" - ], - base=10000.0, - device=inv_freq.device, - scaling_factor=scaling_factor, - extrapolation_factor=1, - attn_factor=1, - beta_fast=32, - beta_slow=1, - ) - else: - raise NotImplementedError( - f"rope scaling type {rope_scaling['type']} is not implemented or invalid" - ) - return cls(inv_freq, scaling_factor) - - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - if self.scaling_factor is not None: - t /= self.scaling_factor - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - def get_cos_sin( - self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype - ): - """ - Return cos and sin for the asked position ids - """ - if IS_ROCM_SYSTEM: - # For RoCm, we always use float cos/sin to avoid a cast. - # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 - # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. - dtype = torch.float32 - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) - - cos = torch.index_select(self._cos_cached, 0, position_ids) - sin = torch.index_select(self._sin_cached, 0, position_ids) - - # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. - return cos.unsqueeze(1), sin.unsqueeze(1) - - class SuRotaryEmbedding(PositionRotaryEmbedding): - def __init__( - self, - short_inv_freq, - long_inv_freq, - scaling_factor, - original_max_position_embeddings, - ): - super(PositionRotaryEmbedding, self).__init__() - self.short_inv_freq = short_inv_freq - self.long_inv_freq = long_inv_freq - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - self.dynamic_args = None - - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - self._seq_len_cached = seqlen - if seqlen > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.short_inv_freq - t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) - if self.scaling_factor is not None: - t /= self.scaling_factor - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - - freqs = torch.outer(t, inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): - def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): - inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - if seqlen > self.max_position_embeddings: - newbase = self.base * ( - (self.scaling_factor * seqlen / self.max_position_embeddings) - - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - self.inv_freq = _create_inv_freq( - self.dim, newbase, self.inv_freq.device - ) - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - # Inverse dim formula to find dim based on number of rotations - import math - - def find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 - ): - return ( - dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) - ) / (2 * math.log(base)) - - # Find dim range bounds based on rotations - def find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 - ): - low = math.floor( - find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - def linear_ramp_mask(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - def get_mscale(scale=1): - if scale <= 1: - return 1.0 - return 0.1 * math.log(scale) + 1.0 - - class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): - def __init__( - self, - dim, - max_position_embeddings, - base, - device, - scaling_factor, - *, - extrapolation_factor, - attn_factor, - beta_fast, - beta_slow, - ): - inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation - - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - if seqlen > self.max_position_embeddings: - inv_freq_extrapolation = _create_inv_freq( - self.dim, self.base, self.inv_freq.device - ) - freqs = 1.0 / inv_freq_extrapolation - inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) - low, high = find_correction_range( - self.beta_fast, - self.beta_slow, - self.dim, - self.base, - self.max_position_embeddings, - ) - inv_freq_mask = ( - 1 - - linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_mask) - + inv_freq_extrapolation * inv_freq_mask - ) - - self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation - - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) - self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) - -except ImportError: - pass diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 62c0c893..1b31f7e7 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,13 +1,9 @@ import torch -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, -) +from text_generation_server.utils.import_utils import SYSTEM _PARTITION_SIZE = 512 -if IS_XPU_SYSTEM: +if SYSTEM == "xpu": import intel_extension_for_pytorch as ipex @@ -18,17 +14,17 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import cache_ops cache_ops.reshape_and_cache( key, value, key_cache, value_cache, slots, "auto", 1.0 ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import cache_ops cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": ipex.llm.modules.PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, slots ) @@ -68,7 +64,7 @@ def attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - if IS_XPU_SYSTEM: + if SYSTEM == "xpu": query = query.contiguous() return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, @@ -91,7 +87,7 @@ def attention( # to parallelize. use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import ops ops.paged_attention_v1( @@ -109,7 +105,7 @@ def attention( "auto", 1.0, ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import attention_ops attention_ops.paged_attention_v1( @@ -143,7 +139,7 @@ def attention( ) max_logits = torch.empty_like(exp_sums) - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import ops ops.paged_attention_v2( @@ -164,7 +160,7 @@ def attention( "auto", 1.0, ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import attention_ops attention_ops.paged_attention_v2( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index da7aed1a..6af7d3fb 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -171,7 +171,7 @@ class Weights: log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) - from text_generation_server.utils.awq.conversion_utils import ( + from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, ) @@ -227,7 +227,7 @@ class Weights: bits, groupsize, desc_act, quant_method = self._get_gptq_params() - from text_generation_server.utils.layers import HAS_EXLLAMA + from text_generation_server.layers.gptq import HAS_EXLLAMA use_exllama = ( bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act @@ -242,7 +242,7 @@ class Weights: log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) - from text_generation_server.utils.awq.conversion_utils import ( + from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, ) @@ -321,7 +321,7 @@ class Weights: # it would require to reorder input activations that are split unto several GPUs use_exllama = False - from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA + from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA if use_exllama: if not HAS_EXLLAMA: @@ -348,7 +348,7 @@ class Weights: log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) - from text_generation_server.utils.awq.conversion_utils import ( + from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, )