parent
73a4d65d26
commit
37df6df38e
|
@ -147,8 +147,10 @@ def serve(
|
|||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.gptq.exllama import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
||||
set_device(model.device)
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
@ -32,9 +32,16 @@ TEMP_STATE = None
|
|||
TEMP_DQ = None
|
||||
|
||||
|
||||
def set_device(device):
    """Remember the target device for later exllama buffer allocation.

    Stores *device* in the module-level ``DEVICE`` global, which
    ``create_exllama_buffers`` reads (it asserts this was called first).
    """
    # Rebind the module-global rather than a local of the same name.
    global DEVICE
    DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers():
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
assert DEVICE is not None, "call set_device first"
|
||||
|
||||
if ACT_ORDER:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
|
|
Loading…
Reference in New Issue