feat: prefer model id in request
commit 68399c1ae3 (parent de56a81c5c)
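The server no longer trusts the adapter_index field carried on each request. Instead it resolves the index itself from the request's adapter_id, via an ADAPTER_TO_INDEX mapping that serve() builds while loading adapters and publishes through models.globals; ids that are not in the mapping fall back to index 0.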
@@ -233,8 +233,9 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
 
-            adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
-            adapter_set.add(r.adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_set.add(adapter_index)
 
             # Paged attention
             # Remove one as the first token does not have a past
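As a sanity check of the new lookup, a minimal self-contained sketch; the mapping and the requests are invented, only the .get(adapter_id, 0) pattern comes from the hunk above:

import torch

# Hypothetical mapping as serve() would publish it; 0 stays free as the fallback.
ADAPTER_TO_INDEX = {"my-lora-adapter": 1}

# (adapter_id, input_length) pairs standing in for batch requests.
requests = [("my-lora-adapter", 3), ("never-loaded", 2)]

adapter_indices_list, adapter_set = [], set()
for adapter_id, input_length in requests:
    adapter_index = ADAPTER_TO_INDEX.get(adapter_id, 0)  # unknown id -> 0
    # Tag every input position of this request with its adapter index.
    adapter_indices_list.append(torch.full((input_length,), adapter_index))
    adapter_set.add(adapter_index)

print(torch.cat(adapter_indices_list))  # tensor([1, 1, 1, 0, 0])
print(adapter_set)                      # {0, 1}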
@@ -498,7 +499,10 @@ class FlashCausalLMBatch(Batch):
 
             top_n_tokens.append(self.top_n_tokens[idx])
 
-            adapter_set.add(self.requests[idx].adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
+                self.requests[idx].adapter_id, 0
+            )
+            adapter_set.add(adapter_index)
 
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
@@ -25,3 +25,14 @@ MODEL_ID = None
 def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id
+
+
+# NOTE: eventually we should move this into the router and pass back the
+# index in all cases.
+global ADAPTER_TO_INDEX
+ADAPTER_TO_INDEX = None
+
+
+def set_adapter_to_index(adapter_to_index: dict):
+    global ADAPTER_TO_INDEX
+    ADAPTER_TO_INDEX = adapter_to_index
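A self-contained sketch of the write-once registry this hunk adds, with an invented mapping, to show the intended startup-then-read lifecycle:

# Mirrors models/globals.py above: set once at startup, read at request time.
ADAPTER_TO_INDEX = None


def set_adapter_to_index(adapter_to_index: dict):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


# Startup, before any request is served (cf. serve() below):
set_adapter_to_index({"my-lora-adapter": 1})  # hypothetical mapping

# Request time, read-only (cf. the FlashCausalLMBatch hunks above):
assert ADAPTER_TO_INDEX.get("my-lora-adapter", 0) == 1
assert ADAPTER_TO_INDEX.get("never-loaded", 0) == 0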
@@ -29,7 +29,7 @@ except (ImportError, NotImplementedError):
 
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_model_id, set_adapter_to_index
 from text_generation_server.utils.adapter import (
     AdapterParameters,
 )
@@ -216,6 +216,7 @@ def serve(
     trust_remote_code: bool = False,
 ):
     unix_socket_template = "unix://{}-{}"
+    adapter_to_index = {}
     if sharded:
         server_urls = [
             unix_socket_template.format(uds_path, rank)
@@ -251,6 +252,7 @@ def serve(
                 majority_sign_method=0,
             )
             adapter_index = index
+            adapter_to_index[adapter_id] = adapter_index
             model.load_adapter(
                 adapter_parameters,
                 None,  # adapter_source
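How the mapping gets filled, reduced to a runnable sketch; the adapter ids are invented and the enumerate() loop is an assumption inferred from adapter_index = index above:

# serve() builds the dict while loading adapters, then publishes it once.
adapter_to_index = {}
lora_adapter_ids = ["my-lora-adapter", "another-adapter"]  # hypothetical CLI input

for index, adapter_id in enumerate(lora_adapter_ids):
    adapter_index = index
    adapter_to_index[adapter_id] = adapter_index
    # model.load_adapter(...) would run here with the same adapter_index.

print(adapter_to_index)  # {'my-lora-adapter': 0, 'another-adapter': 1}

Note that with enumerate() starting at 0, the first adapter would share index 0 with the unknown-id fallback used in the batch hunks; whether the real loop offsets the index to keep 0 for the base model is not visible in this hunk.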
@@ -263,6 +265,7 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
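Ordering matters here: set_adapter_to_index(adapter_to_index) runs after every adapter has been registered but before aio.server() is constructed and starts accepting traffic, and since adapter_to_index is initialized to {} at the top of serve(), ADAPTER_TO_INDEX is always a dict (never None) by the time the batch code calls .get() on it.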