From 37df6df38edb4dc8eee89a42ec3791e89442c851 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 24 Jul 2023 14:25:43 +0200 Subject: [PATCH] fix(server): fix exllama buffers (#689) Close #683 --- server/text_generation_server/server.py | 28 ++++++++++--------- .../utils/gptq/exllama.py | 7 +++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 0929b46f..7c2f1b35 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( - model_id: str, - revision: Optional[str], - sharded: bool, - quantize: Optional[str], - dtype: Optional[str], - trust_remote_code: bool, - uds_path: Path, -): - async def serve_inner( model_id: str, revision: Optional[str], - sharded: bool = False, - quantize: Optional[str] = None, - dtype: Optional[str] = None, - trust_remote_code: bool = False, + sharded: bool, + quantize: Optional[str], + dtype: Optional[str], + trust_remote_code: bool, + uds_path: Path, +): + async def serve_inner( + model_id: str, + revision: Optional[str], + sharded: bool = False, + quantize: Optional[str] = None, + dtype: Optional[str] = None, + trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" if sharded: @@ -147,8 +147,10 @@ def serve( # This will allocate those buffers. from text_generation_server.utils.gptq.exllama import ( create_exllama_buffers, + set_device, ) + set_device(model.device) create_exllama_buffers() except ImportError: pass diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py index e89b725c..6a1cf117 100644 --- a/server/text_generation_server/utils/gptq/exllama.py +++ b/server/text_generation_server/utils/gptq/exllama.py @@ -32,9 +32,16 @@ TEMP_STATE = None TEMP_DQ = None +def set_device(device): + global DEVICE + DEVICE = device + + def create_exllama_buffers(): global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ + assert DEVICE is not None, "call set_device first" + if ACT_ORDER: # TODO: this should be set to rust side `max_total_tokens`, but TGI # does not offer an API to expose this variable to python, as this variable