From 8bd0adb1356d5c6e0238abe64a6cf94b50140db4 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 27 Jul 2023 12:28:10 +0200
Subject: [PATCH] fix(server): fix quantization python requirements (#708)

---
 server/poetry.lock                             | 14 +++++++++++++-
 server/pyproject.toml                          |  1 +
 server/requirements.txt                        |  1 +
 .../text_generation_server/models/flash_rw.py  |  1 -
 server/text_generation_server/server.py        | 26 +++++++++++++-------------
 .../utils/gptq/quantize.py                     |  8 ++------
 6 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/server/poetry.lock b/server/poetry.lock
index 7d00f223..4349340f 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -624,6 +624,14 @@ python-versions = ">=3.8"
 [package.dependencies]
 mpmath = ">=0.19"
 
+[[package]]
+name = "texttable"
+version = "1.6.7"
+description = "module to create simple ASCII tables"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "tokenizers"
 version = "0.13.3"
@@ -810,7 +818,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "65afc4bfa07da4b1427d269fa745939da3851eaede9a8478f5a4bf5949d32cc9"
+content-hash = "c2e0d926748a7d420909c6bd21e17cf060bc7acdd788ae93e3ec1809a4b84529"
 
 [metadata.files]
 accelerate = [
@@ -1484,6 +1492,10 @@ sympy = [
     {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
     {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
 ]
+texttable = [
+    {file = "texttable-1.6.7-py2.py3-none-any.whl", hash = "sha256:b7b68139aa8a6339d2c320ca8b1dc42d13a7831a346b446cb9eb385f0c76310c"},
+    {file = "texttable-1.6.7.tar.gz", hash = "sha256:290348fb67f7746931bcdfd55ac7584ecd4e5b0846ab164333f0794b121760f2"},
+]
 tokenizers = [
     {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"},
     {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"},
diff --git a/server/pyproject.toml b/server/pyproject.toml
index be79da51..a0fc4411 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -28,6 +28,7 @@ tokenizers = "0.13.3"
 huggingface-hub = "^0.14.1"
 transformers = "4.29.2"
 einops = "^0.6.1"
+texttable = "^1.6.7"
 
 [tool.poetry.extras]
 accelerate = ["accelerate"]
diff --git a/server/requirements.txt b/server/requirements.txt
index 9b8b2164..02954f24 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -35,6 +35,7 @@ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
 safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
 setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
+texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
 transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 9e0080a9..7d655945 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -61,7 +61,6 @@ class FlashRWSharded(FlashCausalLM):
 
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
-
         model = FlashRWForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 7c2f1b35..1cedc151 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 
 def serve(
-        model_id: str,
-        revision: Optional[str],
-        sharded: bool,
-        quantize: Optional[str],
-        dtype: Optional[str],
-        trust_remote_code: bool,
-        uds_path: Path,
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    uds_path: Path,
 ):
     async def serve_inner(
-            model_id: str,
-            revision: Optional[str],
-            sharded: bool = False,
-            quantize: Optional[str] = None,
-            dtype: Optional[str] = None,
-            trust_remote_code: bool = False,
+        model_id: str,
+        revision: Optional[str],
+        sharded: bool = False,
+        quantize: Optional[str] = None,
+        dtype: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index 6eb44e41..3f8e897a 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -1,18 +1,14 @@
-import argparse
 import time
-import numpy as np
-import torch
 import torch.nn as nn
 import math
 import json
 import os
+import torch
+import transformers
 
 from texttable import Texttable
 from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
-import transformers
 from huggingface_hub import HfApi
-import numpy as np
-import torch
 from accelerate import init_empty_weights
 from text_generation_server.utils import initialize_torch_distributed, Weights
 from text_generation_server.utils.hub import weight_files
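
Note: the `texttable` pin exists because
server/text_generation_server/utils/gptq/quantize.py imports it at module
level (last hunk above), so the GPTQ quantization path raised
ModuleNotFoundError whenever the package was missing from the environment.
For context only, here is a minimal, self-contained sketch of what the
dependency is used for; it is not code from this patch, and the layer names
and error values are made up for illustration:

    # pip install texttable==1.6.7
    from texttable import Texttable

    table = Texttable()
    table.set_cols_dtype(["t", "i", "f"])  # text, integer, float columns
    table.header(["layer", "bits", "avg quant error"])
    # Hypothetical per-layer results, purely illustrative:
    table.add_row(["model.layers.0.self_attn.q_proj", 4, 0.0123])
    table.add_row(["model.layers.0.mlp.down_proj", 4, 0.0087])
    print(table.draw())

This prints a bordered ASCII table, the kind of per-layer summary a
quantization script emits while it runs.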