Upgrading exl2. (#2415)

* Upgrading exl2.
* Fixing the other pathways.
* Fix idefics.

Parent: c5fff92b48
Commit: f3b5c69441
@@ -9,7 +9,7 @@ backends/client/src/v3/pb
 # ROCm auto-generated files
 *.hip
-server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllamav2
 server/exllama_kernels/exllama_kernels/hip/
 server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
@@ -123,10 +123,10 @@ RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
 # Build Transformers exllama kernels
 FROM kernel-builder AS exllamav2-kernels-builder
 WORKDIR /usr/src
-COPY server/exllamav2_kernels/ .
+COPY server/Makefile-exllamav2/ Makefile

 # Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-exllamav2

 # Build Transformers awq kernels
 FROM kernel-builder AS awq-kernels-builder
@@ -221,7 +221,7 @@ COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /
 # Copy build artifacts from exllama kernels builder
 COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
@@ -93,6 +93,7 @@
 causal-conv1d
 click
 einops
+exllamav2
 fbgemm-gpu
 flashinfer
 flash-attn
@@ -6,6 +6,7 @@ include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
 include Makefile-fbgemm
+include Makefile-exllamav2

 unit-tests:
 	pytest -s -vv -m "not private" tests
@@ -0,0 +1,12 @@
+exllamav2_commit := v0.1.8
+
+build-exllamav2:
+	git clone https://github.com/turboderp/exllamav2.git exllamav2 && \
+	cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \
+	git submodule update --init --recursive && \
+	pip install -r requirements.txt && \
+	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
+
+install-exllamav2: build-exllamav2
+	cd exllamav2/ && \
+	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install
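The new Makefile above clones upstream exllamav2, pins it to the v0.1.8 tag, and builds the CUDA extension from source instead of using the vendored exllamav2_kernels tree. A minimal sanity check in Python, not part of the commit; it assumes the wheel is installed under the distribution name "exllamav2" and that the v0.1.8 tag corresponds to package version 0.1.8:

# Sketch: confirm the pinned exllamav2 release actually landed in the image.
# Assumes the distribution name "exllamav2"; expected version mirrors the tag above.
from importlib.metadata import PackageNotFoundError, version

try:
    print("exllamav2", version("exllamav2"), "(expected 0.1.8)")
except PackageNotFoundError:
    print("exllamav2 is not installed")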
@@ -12,7 +12,10 @@ from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_master

 try:
-    from exllamav2_kernels import make_q_matrix, gemm_half_q_half
+    from exllamav2.ext import exllamav2_ext
+
+    make_q_matrix = exllamav2_ext.make_q_matrix
+    gemm_half_q_half = exllamav2_ext.gemm_half_q_half
 except ImportError:
     log_master(logger.warning, "exllamav2_kernels not installed.")
     raise
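The kernels now come from the upstream exllamav2 package, whose exllamav2.ext module exposes the compiled extension as exllamav2_ext, rather than from the old standalone exllamav2_kernels module. A quick standalone check that the rebuilt extension exposes the two symbols the layer aliases (a sketch, assuming exllamav2 v0.1.8 is importable):

# Sketch: verify the upstream extension provides the symbols aliased above.
from exllamav2.ext import exllamav2_ext

for name in ("make_q_matrix", "gemm_half_q_half"):
    if not hasattr(exllamav2_ext, name):
        raise ImportError(f"exllamav2_ext is missing {name}")
print("exllamav2_ext provides make_q_matrix and gemm_half_q_half")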
@@ -70,6 +73,10 @@ def ext_make_q_matrix(
     """
     Create Q matrix
     """
+    # max_dq_size = 512*(1024**2)
+    # max_dq_rows = max_dq_size // out_features[0]
+    max_dq_rows = 0
+
     # EXL2
     if isinstance(w, Exl2Weight):
         extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
@@ -83,10 +90,12 @@ def ext_make_q_matrix(
             w.q_scale_max,
             w.q_groups,
             extra.q_group_map,
-            none_tensor,
-            none_tensor,
-            none_tensor,
+            none_tensor,  # zeros
+            none_tensor,  # scales
+            none_tensor,  # g_idx
+            none_tensor,  # bias
             temp_dq,
+            max_dq_rows,
         )
     # GPTQ
     elif isinstance(w, GPTQWeight):
@@ -106,29 +115,33 @@ def ext_make_q_matrix(
                 w.qweight,
                 extra.q_perm,
                 extra.q_invperm,
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
+                none_tensor,  # q_scale
+                none_tensor,  # q_scale_max
+                none_tensor,  # q_groups
+                none_tensor,  # q_group_map
                 w.qzeros,
                 w.scales,
                 w.g_idx.cpu(),
+                none_tensor,  # bias
                 temp_dq,
+                max_dq_rows,
             )
         # GPTQ without g_idx
         else:
             return make_q_matrix(
                 w.qweight,
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
+                none_tensor,  # q_perm
+                none_tensor,  # q_invperm
+                none_tensor,  # q_scale
+                none_tensor,  # q_scale_max
+                none_tensor,  # q_groups
+                none_tensor,  # q_group_map
                 w.qzeros,
                 w.scales,
-                none_tensor,
+                none_tensor,  # g_idx
+                none_tensor,  # bias
                 temp_dq,
+                max_dq_rows,
             )
     else:
         RuntimeError("Cannot create handle")
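Per the diff, make_q_matrix in exllamav2 v0.1.8 takes two extra trailing arguments: a bias slot (passed as none_tensor here) and max_dq_rows, which the commit sets to 0. The commented-out lines above show how a nonzero value would be derived; a short worked example of that arithmetic, where the layer width is a hypothetical value, not something the commit specifies:

# Worked sketch of the commented-out formula; out_features is illustrative only.
max_dq_size = 512 * (1024**2)
out_features = [4096]  # hypothetical layer width
max_dq_rows = max_dq_size // out_features[0]
print(max_dq_rows)  # 131072 for a 4096-wide layer; the commit passes 0 instead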
@@ -511,6 +511,7 @@ class CausalLM(Model):
         config_class=AutoConfig,
         batch_class=CausalLMBatch,
     ):
+        self.quantize = quantize
         self.batch_class = batch_class
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -872,6 +872,7 @@ class FlashCausalLM(Model):
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -33,6 +33,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -580,6 +580,7 @@ class IdeficsCausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         from text_generation_server.models.custom_modeling.idefics_modeling import (
             IdeficsForVisionText2Text,
         )

@@ -553,6 +553,7 @@ class Seq2SeqLM(Model):
         tokenizer_class=AutoTokenizer,
         aliases=None,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
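Each of the model constructors above now records the resolved quantization method on the instance, so downstream code can read it from the loaded model instead of having it threaded through separately. A standalone sketch of the pattern, using a stand-in class rather than TGI's real Model:

# Stand-in sketch of the pattern added in every constructor above.
from typing import Optional


class DummyModel:
    def __init__(self, quantize: Optional[str] = None):
        # mirrors the `self.quantize = quantize` lines added in the diff
        self.quantize = quantize


model = DummyModel(quantize="exl2")
assert model.quantize == "exl2"  # the gRPC service below reads exactly this attribute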
@@ -50,12 +50,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self,
         model: Model,
         cache: Cache,
-        quantize: Optional[str],
         server_urls: List[str],
     ):
         self.cache = cache
         self.model = model
-        self.quantize = quantize
+        # Quantize is resolved during model loading
+        self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         if model.device.type == "cuda":

@@ -255,7 +255,7 @@ def serve(
             ],
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-            TextGenerationService(model, Cache(), quantize, server_urls), server
+            TextGenerationService(model, Cache(), server_urls), server
         )
         SERVICE_NAMES = (
             generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
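With the models owning their quantize value, the gRPC service drops its quantize constructor argument and reads model.quantize instead, and the serve() call site passes one argument fewer. A self-contained sketch of the new wiring, with stand-in classes rather than TGI's real ones:

# Stand-in sketch of the constructor change: no `quantize` parameter; the value
# is read from the already-loaded model.
from typing import List, Optional


class FakeModel:
    def __init__(self, quantize: Optional[str]):
        self.quantize = quantize  # resolved during model loading


class FakeTextGenerationService:
    def __init__(self, model: FakeModel, cache: dict, server_urls: List[str]):
        self.cache = cache
        self.model = model
        # Quantize is resolved during model loading
        self.quantize = model.quantize
        self.server_urls = server_urls


service = FakeTextGenerationService(FakeModel("exl2"), cache={}, server_urls=["unix:///tmp/tgi-0"])
assert service.quantize == "exl2"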