43 lines
905 B
Python
43 lines
905 B
Python
|
import typer
|
||
|
|
||
|
from pathlib import Path
|
||
|
from torch.distributed.launcher import launch_agent, LaunchConfig
|
||
|
from typing import Optional
|
||
|
|
||
|
from bloom_inference import server
|
||
|
|
||
|
# Typer application object; the functions below register themselves on it
# as CLI sub-commands through the ``@app.command()`` decorator.
app = typer.Typer()
|
||
|
|
||
|
|
||
|
@app.command()
def launcher(
    model_name: str,
    num_gpus: int = 1,
    shard_directory: Optional[Path] = None,
):
    """Run ``server.serve`` in-process or as ``num_gpus`` sharded workers.

    With ``num_gpus == 1`` the ``serve`` command is called directly in the
    current process with sharding disabled. Otherwise ``num_gpus`` worker
    processes are started on a single node through
    ``torch.distributed.launcher.launch_agent``, each one running
    ``server.serve`` with sharding enabled.

    Args:
        model_name: Forwarded unchanged to ``server.serve``.
        num_gpus: Number of worker processes to start; must be at least 1.
        shard_directory: Optional path forwarded unchanged to
            ``server.serve``.

    Raises:
        typer.BadParameter: If ``num_gpus`` is smaller than 1.
    """
    # Reject non-positive counts up front: they would otherwise fall through
    # to the distributed branch and ask the launcher for zero (or a negative
    # number of) worker processes.
    if num_gpus < 1:
        raise typer.BadParameter(
            f"num_gpus must be at least 1, got {num_gpus}",
            param_hint="--num-gpus",
        )

    if num_gpus == 1:
        serve(model_name, False, shard_directory)
        return

    # Single-node launch (min_nodes == max_nodes == 1) with one process per
    # requested GPU; max_restarts=0 means failed workers are not restarted.
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=num_gpus,
        rdzv_backend="c10d",
        max_restarts=0,
    )
    launch_agent(config, server.serve, [model_name, True, shard_directory])
|
||
|
|
||
|
|
||
|
@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
):
    """Forward the CLI arguments to ``bloom_inference.server.serve``.

    The three values are passed positionally, unchanged and in the same
    order: ``model_name``, ``sharded``, ``shard_directory``.
    """
    forwarded = (model_name, sharded, shard_directory)
    server.serve(*forwarded)
|
||
|
|
||
|
|
||
|
# Script entry point: when executed directly (not imported), hand control
# to the Typer application so it can dispatch to the requested command.
if __name__ == "__main__":
    app()
|