fix(server): better handling of inference mode (#57)

This commit is contained in:
OlivierDehaene 2023-02-07 15:38:22 +01:00 committed by GitHub
parent e114d87486
commit 4acc42a605
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 25 deletions

View File

@ -38,9 +38,9 @@ struct Args {
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String,
#[clap(default_value = "localhost", long, env)]
#[clap(default_value = "0.0.0.0", long, env)]
master_addr: String,
#[clap(default_value = "29500", long, env)]
#[clap(default_value = "6000", long, env)]
master_port: usize,
#[clap(long, env)]
json_output: bool,

View File

@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Disable gradients
torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool

View File

@ -289,17 +289,12 @@ class CausalLM(Model):
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
logits, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.position_ids,
batch.past_key_values,
)
with context_manager():
logits, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.position_ids,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []

View File

@ -364,19 +364,14 @@ class Seq2SeqLM(Model):
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.decoder_input_ids,
batch.decoder_attention_mask,
batch.encoder_last_hidden_state,
batch.past_key_values,
)
with context_manager():
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.decoder_input_ids,
batch.decoder_attention_mask,
batch.encoder_last_hidden_state,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []

View File

@ -1,5 +1,6 @@
import asyncio
import os
import torch
from grpc import aio
from loguru import logger
@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.cache = cache
self.model = model
self.server_urls = server_urls
# For some reason, inference_mode does not work well with GLOO which we use on CPU
if model.device.type == "cuda":
# Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def ServiceDiscovery(self, request, context):
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
@ -89,7 +94,11 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
model = get_model(model_id, revision, sharded, quantize)
try:
model = get_model(model_id, revision, sharded, quantize)
except Exception:
logger.exception("Error when initializing model")
raise
server = aio.server(interceptors=[ExceptionInterceptor()])
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@ -101,8 +110,11 @@ def serve(
)
reflection.enable_server_reflection(SERVICE_NAMES, server)
server.add_insecure_port(local_url)
await server.start()
logger.info("Server started at {}".format(local_url))
try:
await server.wait_for_termination()
except KeyboardInterrupt:

View File

@ -171,9 +171,14 @@ def initialize_torch_distributed():
else:
backend = "gloo"
master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
master_port = os.getenv("MASTER_PORT", "6000")
init_method = f"tcp://{master_ip}:{master_port}"
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
init_method=init_method,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),