From f7f61876cff78d934b96cb80a8b312d5f9600802 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 31 Jul 2024 10:27:15 -0400 Subject: [PATCH] Pr 2290 ci run (#2329) * MODEL_ID propagation fix * fix: remove global model id --------- Co-authored-by: root --- server/text_generation_server/models/flash_causal_lm.py | 3 +-- server/text_generation_server/models/globals.py | 9 --------- server/text_generation_server/server.py | 3 +-- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 152516e7..36bb2662 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,7 +43,6 @@ from text_generation_server.models.globals import ( BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, - MODEL_ID, ) from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser @@ -1156,7 +1155,7 @@ class FlashCausalLM(Model): tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, - f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) log_master( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index ac42df30..8d2431db 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -29,15 +29,6 @@ if cuda_graphs is not None: CUDA_GRAPHS = cuda_graphs -# This is overridden at model loading. -MODEL_ID = None - - -def set_model_id(model_id: str): - global MODEL_ID - MODEL_ID = model_id - - # NOTE: eventually we should move this into the router and pass back the # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 22bd759f..b92ab572 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -30,7 +30,7 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id, set_adapter_to_index +from text_generation_server.models.globals import set_adapter_to_index class SignalHandler: @@ -271,7 +271,6 @@ def serve( while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) - set_model_id(model_id) asyncio.run( serve_inner( model_id,