diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 3df6e911..6684c3de 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -38,9 +38,9 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
     shard_uds_path: String,
-    #[clap(default_value = "localhost", long, env)]
+    #[clap(default_value = "0.0.0.0", long, env)]
     master_addr: String,
-    #[clap(default_value = "29500", long, env)]
+    #[clap(default_value = "6000", long, env)]
     master_port: usize,
     #[clap(long, env)]
     json_output: bool,
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 943d45e9..291705a6 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Disable gradients
+torch.set_grad_enabled(False)
+
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index f21423ea..0cbed22e 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -289,17 +289,12 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.position_ids,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                batch.past_key_values,
-            )
 
         # List of indices to cache
         next_batch_keep_indices = []
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 27cbe1c0..80aecbac 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -364,19 +364,14 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, encoder_last_hidden_state, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.decoder_input_ids,
+            batch.decoder_attention_mask,
+            batch.encoder_last_hidden_state,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.decoder_input_ids,
-                batch.decoder_attention_mask,
-                batch.encoder_last_hidden_state,
-                batch.past_key_values,
-            )
 
         # List of indices to cache
         next_batch_keep_indices = []
diff --git a/server/text_generation/server.py b/server/text_generation/server.py
index a8a9da6c..68dc7cd0 100644
--- a/server/text_generation/server.py
+++ b/server/text_generation/server.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import torch
 
 from grpc import aio
 from loguru import logger
@@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
@@ -89,7 +94,11 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]
 
-        model = get_model(model_id, revision, sharded, quantize)
+        try:
+            model = get_model(model_id, revision, sharded, quantize)
+        except Exception:
+            logger.exception("Error when initializing model")
+            raise
 
         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -101,8 +110,11 @@
         )
         reflection.enable_server_reflection(SERVICE_NAMES, server)
         server.add_insecure_port(local_url)
+
         await server.start()
 
+        logger.info("Server started at {}".format(local_url))
+
         try:
             await server.wait_for_termination()
         except KeyboardInterrupt:
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 83458969..2821a124 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -171,9 +171,14 @@
     else:
         backend = "gloo"
 
+    master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
+    master_port = os.getenv("MASTER_PORT", "6000")
+    init_method = f"tcp://{master_ip}:{master_port}"
+
     # Call the init process.
     torch.distributed.init_process_group(
         backend=backend,
+        init_method=init_method,
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
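
Note on the guard introduced in server.py: holding torch._C._InferenceMode(True) keeps inference mode enabled for the lifetime of the service instead of re-entering a context manager around every forward pass, while CPU shards rely only on the global torch.set_grad_enabled(False) because inference mode interacts badly with the GLOO backend. A minimal standalone sketch of the same pattern, assuming a PyTorch version that exposes the private torch._C._InferenceMode guard (the class name InferenceModeHolder is illustrative only):

    import torch

    torch.set_grad_enabled(False)  # global, mirrors models/__init__.py above


    class InferenceModeHolder:
        def __init__(self, device: torch.device):
            if device.type == "cuda":
                # RAII guard: inference mode stays on as long as this object is alive
                self._inference_mode_raii_guard = torch._C._InferenceMode(True)


    holder = InferenceModeHolder(torch.device("cpu"))
    print(torch.is_inference_mode_enabled())  # False on CPU; True while a CUDA guard is alive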