feat: prefer model id in request
This commit is contained in:
parent
de56a81c5c
commit
68399c1ae3
|
@ -233,8 +233,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
top_n_tokens.append(r.top_n_tokens)
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
|
|
||||||
adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
|
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||||
adapter_set.add(r.adapter_index)
|
adapter_indices_list.append(torch.full((input_length,), adapter_index))
|
||||||
|
adapter_set.add(adapter_index)
|
||||||
|
|
||||||
# Paged attention
|
# Paged attention
|
||||||
# Remove one as the first token des not have a past
|
# Remove one as the first token des not have a past
|
||||||
|
@ -498,7 +499,10 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
top_n_tokens.append(self.top_n_tokens[idx])
|
top_n_tokens.append(self.top_n_tokens[idx])
|
||||||
|
|
||||||
adapter_set.add(self.requests[idx].adapter_index)
|
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
|
||||||
|
self.requests[idx].adapter_id, 0
|
||||||
|
)
|
||||||
|
adapter_set.add(adapter_index)
|
||||||
|
|
||||||
remaining_tokens = (
|
remaining_tokens = (
|
||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
|
|
|
@ -25,3 +25,14 @@ MODEL_ID = None
|
||||||
def set_model_id(model_id: str):
|
def set_model_id(model_id: str):
|
||||||
global MODEL_ID
|
global MODEL_ID
|
||||||
MODEL_ID = model_id
|
MODEL_ID = model_id
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: eventually we should move this into the router and pass back the
|
||||||
|
# index in all cases.
|
||||||
|
global ADAPTER_TO_INDEX
|
||||||
|
ADAPTER_TO_INDEX = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_adapter_to_index(adapter_to_index: dict):
|
||||||
|
global ADAPTER_TO_INDEX
|
||||||
|
ADAPTER_TO_INDEX = adapter_to_index
|
||||||
|
|
|
@ -29,7 +29,7 @@ except (ImportError, NotImplementedError):
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||||
from text_generation_server.models.globals import set_model_id
|
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
|
||||||
from text_generation_server.utils.adapter import (
|
from text_generation_server.utils.adapter import (
|
||||||
AdapterParameters,
|
AdapterParameters,
|
||||||
)
|
)
|
||||||
|
@ -216,6 +216,7 @@ def serve(
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
unix_socket_template = "unix://{}-{}"
|
unix_socket_template = "unix://{}-{}"
|
||||||
|
adapter_to_index = {}
|
||||||
if sharded:
|
if sharded:
|
||||||
server_urls = [
|
server_urls = [
|
||||||
unix_socket_template.format(uds_path, rank)
|
unix_socket_template.format(uds_path, rank)
|
||||||
|
@ -251,6 +252,7 @@ def serve(
|
||||||
majority_sign_method=0,
|
majority_sign_method=0,
|
||||||
)
|
)
|
||||||
adapter_index = index
|
adapter_index = index
|
||||||
|
adapter_to_index[adapter_id] = adapter_index
|
||||||
model.load_adapter(
|
model.load_adapter(
|
||||||
adapter_parameters,
|
adapter_parameters,
|
||||||
None, # adapter_source
|
None, # adapter_source
|
||||||
|
@ -263,6 +265,7 @@ def serve(
|
||||||
logger.exception("Error when initializing model")
|
logger.exception("Error when initializing model")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
set_adapter_to_index(adapter_to_index)
|
||||||
server = aio.server(
|
server = aio.server(
|
||||||
interceptors=[
|
interceptors=[
|
||||||
ExceptionInterceptor(),
|
ExceptionInterceptor(),
|
||||||
|
|
Loading…
Reference in New Issue