feat: prefer model id in request

drbh 2024-06-06 23:21:10 +00:00
parent de56a81c5c
commit 68399c1ae3
3 changed files with 22 additions and 4 deletions

View File

@@ -233,8 +233,9 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)

-            adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
-            adapter_set.add(r.adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_set.add(adapter_index)

             # Paged attention
             # Remove one as the first token des not have a past
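For context on this hunk: instead of trusting an adapter_index field carried on the request, the batch now resolves the request's adapter_id through the global ADAPTER_TO_INDEX mapping, falling back to 0 (the base model) when the id is unknown. A minimal, self-contained sketch of that resolution; only ADAPTER_TO_INDEX, the 0 fallback, and the torch.full/adapter_set usage come from the diff, the mapping contents and request shape are illustrative:

import torch

# Illustrative mapping; in TGI it is populated at startup (see server.py below).
ADAPTER_TO_INDEX = {"my-lora": 1}


def build_adapter_indices(requests):
    """requests: list of (adapter_id, input_length) pairs (hypothetical shape)."""
    adapter_indices_list = []
    adapter_set = set()
    for adapter_id, input_length in requests:
        # Unknown or empty adapter ids fall back to index 0, i.e. the base model.
        adapter_index = ADAPTER_TO_INDEX.get(adapter_id, 0)
        adapter_indices_list.append(torch.full((input_length,), adapter_index))
        adapter_set.add(adapter_index)
    return torch.cat(adapter_indices_list), adapter_set


indices, used = build_adapter_indices([("my-lora", 3), ("not-loaded", 2)])
# indices -> tensor([1, 1, 1, 0, 0]); used -> {0, 1}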
@@ -498,7 +499,10 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])

-            adapter_set.add(self.requests[idx].adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
+                self.requests[idx].adapter_id, 0
+            )
+            adapter_set.add(adapter_index)

             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
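The filter path repeats the same lookup, which also changes what the server trusts: the index now comes from the adapter_id the client sent rather than from a client-supplied adapter_index. A tiny hypothetical before/after comparison; the Request stand-in below is not the actual generate_pb2 message:

ADAPTER_TO_INDEX = {"my-lora": 1}


class Request:
    # Hypothetical stand-in for the protobuf request used above.
    def __init__(self, adapter_id="", adapter_index=0):
        self.adapter_id = adapter_id
        self.adapter_index = adapter_index


req = Request(adapter_id="my-lora", adapter_index=99)  # stale/arbitrary index

old_index = req.adapter_index                        # 99, taken on faith
new_index = ADAPTER_TO_INDEX.get(req.adapter_id, 0)  # 1, resolved server-side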

View File

@@ -25,3 +25,14 @@ MODEL_ID = None
 def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id
+
+
+# NOTE: eventually we should move this into the router and pass back the
+# index in all cases.
+global ADAPTER_TO_INDEX
+ADAPTER_TO_INDEX = None
+
+
+def set_adapter_to_index(adapter_to_index: dict):
+    global ADAPTER_TO_INDEX
+    ADAPTER_TO_INDEX = adapter_to_index
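The new ADAPTER_TO_INDEX global follows the same pattern as MODEL_ID: the server process sets it once at startup through the setter, and other modules read it through the globals module (the tgi_globals.ADAPTER_TO_INDEX access in the batch code). A self-contained sketch of that pattern plus a lookup helper; the helper and the type hints are illustrative, not part of the diff:

from typing import Dict, Optional

MODEL_ID: Optional[str] = None
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def adapter_index_for(adapter_id: str) -> int:
    # Mirrors the .get(adapter_id, 0) lookups in the batch code: unknown ids
    # (or an unset mapping) resolve to 0, the base model.
    return (ADAPTER_TO_INDEX or {}).get(adapter_id, 0)

Reading through the module object rather than importing the name directly is what makes this work: callers see the value assigned by the setter instead of a stale binding captured at import time.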

View File

@@ -29,7 +29,7 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_model_id, set_adapter_to_index
 from text_generation_server.utils.adapter import (
     AdapterParameters,
 )
@@ -216,6 +216,7 @@ def serve(
     trust_remote_code: bool = False,
 ):
     unix_socket_template = "unix://{}-{}"
+    adapter_to_index = {}
     if sharded:
         server_urls = [
             unix_socket_template.format(uds_path, rank)
@@ -251,6 +252,7 @@ def serve(
                     majority_sign_method=0,
                 )
                 adapter_index = index
+                adapter_to_index[adapter_id] = adapter_index
                 model.load_adapter(
                     adapter_parameters,
                     None,  # adapter_source
@@ -263,6 +265,7 @@ def serve(
             logger.exception("Error when initializing model")
             raise

+        set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
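Taken together, serve builds the id-to-index map while loading each configured LoRA adapter and publishes it via set_adapter_to_index before the gRPC server starts, so batch construction can translate request adapter ids. A condensed, self-contained sketch of that wiring; the stub model and serve_sketch function are illustrative, and the real load_adapter call takes AdapterParameters and further arguments:

ADAPTER_TO_INDEX = None


def set_adapter_to_index(adapter_to_index: dict):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


class StubModel:
    def load_adapter(self, adapter_id: str, adapter_index: int):
        print(f"loading {adapter_id} at index {adapter_index}")


def serve_sketch(lora_adapter_ids):
    model = StubModel()
    adapter_to_index = {}
    for index, adapter_id in enumerate(lora_adapter_ids):
        adapter_index = index
        adapter_to_index[adapter_id] = adapter_index
        model.load_adapter(adapter_id, adapter_index)
    # Publish the mapping before the gRPC server starts accepting requests,
    # so batch construction can resolve each request's adapter_id.
    set_adapter_to_index(adapter_to_index)


serve_sketch(["adapter-a", "adapter-b"])
# ADAPTER_TO_INDEX -> {"adapter-a": 0, "adapter-b": 1}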