hf_text-generation-inference/server/bloom_inference/cli.py

import typer

from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional

from bloom_inference import server

app = typer.Typer()


@app.command()
def launcher(
    model_name: str,
    num_gpus: int = 1,
    shard_directory: Optional[Path] = None,
):
    if num_gpus == 1:
        # Single GPU: run the server directly in this process, unsharded.
        serve(model_name, False, shard_directory)
    else:
        # Multiple GPUs: use torch.distributed.elastic to spawn one
        # sharded server process per GPU on this single node.
        config = LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=num_gpus,
            rdzv_backend="c10d",
            max_restarts=0,
        )
        launch_agent(config, server.serve, [model_name, True, shard_directory])


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
):
    server.serve(model_name, sharded, shard_directory)


if __name__ == "__main__":
    app()
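
# Usage sketch (an assumption, not from the repository): typer exposes each
# decorated function as a subcommand and, by default, turns underscored
# parameters with defaults into hyphenated options. Assuming the package is
# importable as bloom_inference, invocations would look like:
#
#   python -m bloom_inference.cli serve bigscience/bloom-560m
#   python -m bloom_inference.cli launcher bigscience/bloom \
#       --num-gpus 8 --shard-directory /data/shards
#
# The model names and shard path above are illustrative placeholders.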