Pr 2290 ci run (#2329)
* MODEL_ID propagation fix * fix: remove global model id --------- Co-authored-by: root <root@tw031.pit.tensorwave.lan>
This commit is contained in:
parent
34f7dcfd80
commit
f7f61876cf
|
@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
CUDA_GRAPHS,
|
CUDA_GRAPHS,
|
||||||
get_adapter_to_index,
|
get_adapter_to_index,
|
||||||
MODEL_ID,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||||
|
@ -1156,7 +1155,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
tunableop_filepath = os.path.join(
|
tunableop_filepath = os.path.join(
|
||||||
HUGGINGFACE_HUB_CACHE,
|
HUGGINGFACE_HUB_CACHE,
|
||||||
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||||
)
|
)
|
||||||
|
|
||||||
log_master(
|
log_master(
|
||||||
|
|
|
@ -29,15 +29,6 @@ if cuda_graphs is not None:
|
||||||
|
|
||||||
CUDA_GRAPHS = cuda_graphs
|
CUDA_GRAPHS = cuda_graphs
|
||||||
|
|
||||||
# This is overridden at model loading.
|
|
||||||
MODEL_ID = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_model_id(model_id: str):
|
|
||||||
global MODEL_ID
|
|
||||||
MODEL_ID = model_id
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: eventually we should move this into the router and pass back the
|
# NOTE: eventually we should move this into the router and pass back the
|
||||||
# index in all cases.
|
# index in all cases.
|
||||||
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
|
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
|
||||||
|
|
|
@ -30,7 +30,7 @@ except (ImportError, NotImplementedError):
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||||
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
|
from text_generation_server.models.globals import set_adapter_to_index
|
||||||
|
|
||||||
|
|
||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
|
@ -271,7 +271,6 @@ def serve(
|
||||||
while signal_handler.KEEP_PROCESSING:
|
while signal_handler.KEEP_PROCESSING:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
set_model_id(model_id)
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
serve_inner(
|
serve_inner(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
Loading…
Reference in New Issue