fix(server): fix quantization python requirements (#708)

This commit is contained in:
OlivierDehaene 2023-07-27 12:28:10 +02:00 committed by GitHub
parent e64a65891b
commit 8bd0adb135
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 30 additions and 21 deletions

14
server/poetry.lock generated
View File

@ -624,6 +624,14 @@ python-versions = ">=3.8"
[package.dependencies]
mpmath = ">=0.19"
[[package]]
name = "texttable"
version = "1.6.7"
description = "module to create simple ASCII tables"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "tokenizers"
version = "0.13.3"
@ -810,7 +818,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "65afc4bfa07da4b1427d269fa745939da3851eaede9a8478f5a4bf5949d32cc9"
content-hash = "c2e0d926748a7d420909c6bd21e17cf060bc7acdd788ae93e3ec1809a4b84529"
[metadata.files]
accelerate = [
@ -1484,6 +1492,10 @@ sympy = [
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
]
texttable = [
{file = "texttable-1.6.7-py2.py3-none-any.whl", hash = "sha256:b7b68139aa8a6339d2c320ca8b1dc42d13a7831a346b446cb9eb385f0c76310c"},
{file = "texttable-1.6.7.tar.gz", hash = "sha256:290348fb67f7746931bcdfd55ac7584ecd4e5b0846ab164333f0794b121760f2"},
]
tokenizers = [
{file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"},
{file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"},

View File

@ -28,6 +28,7 @@ tokenizers = "0.13.3"
huggingface-hub = "^0.14.1"
transformers = "4.29.2"
einops = "^0.6.1"
texttable = "^1.6.7"
[tool.poetry.extras]
accelerate = ["accelerate"]

View File

@ -35,6 +35,7 @@ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -61,7 +61,6 @@ class FlashRWSharded(FlashCausalLM):
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)

View File

@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:

View File

@ -1,18 +1,14 @@
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import math
import json
import os
import torch
import transformers
from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import transformers
from huggingface_hub import HfApi
import numpy as np
import torch
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files