Pr 2290 ci run (#2329)

* MODEL_ID propagation fix

* fix: remove global model id

---------

Co-authored-by: root <root@tw031.pit.tensorwave.lan>
This commit is contained in:
drbh 2024-07-31 10:27:15 -04:00 committed by GitHub
parent 34f7dcfd80
commit f7f61876cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 2 additions and 13 deletions

View File

@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
get_adapter_to_index, get_adapter_to_index,
MODEL_ID,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -1156,7 +1155,7 @@ class FlashCausalLM(Model):
tunableop_filepath = os.path.join( tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE, HUGGINGFACE_HUB_CACHE,
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
) )
log_master( log_master(

View File

@ -29,15 +29,6 @@ if cuda_graphs is not None:
CUDA_GRAPHS = cuda_graphs CUDA_GRAPHS = cuda_graphs
# This is overridden at model loading.
MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the # NOTE: eventually we should move this into the router and pass back the
# index in all cases. # index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

View File

@ -30,7 +30,7 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index from text_generation_server.models.globals import set_adapter_to_index
class SignalHandler: class SignalHandler:
@ -271,7 +271,6 @@ def serve(
while signal_handler.KEEP_PROCESSING: while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
set_model_id(model_id)
asyncio.run( asyncio.run(
serve_inner( serve_inner(
model_id, model_id,