hf_text-generation-inference/server/bloom_inference/cli.py

import os
import typer
from pathlib import Path
from typing import Optional

from bloom_inference import server

# Imported under an alias so the `prepare_weights` command defined below does
# not shadow the module and break the call inside its body.
from bloom_inference import prepare_weights as prepare_weights_module

app = typer.Typer()

@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
    uds_path: Path = Path("/tmp/bloom-inference"),
):
    if sharded:
        # Tensor-parallel serving needs the pre-sharded weights and the usual
        # torch.distributed environment variables set by the launcher.
        assert (
            shard_directory is not None
        ), "shard_directory must be set when sharded is True"
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    server.serve(model_name, sharded, uds_path, shard_directory)
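

# Example invocation of the sharded path, as a sketch only: the model name,
# port, and paths below are placeholders (not taken from this repository), and
# it assumes the package is importable as `bloom_inference` and that Typer
# exposes the boolean parameter as the `--sharded` flag.
#
#   RANK=0 WORLD_SIZE=1 MASTER_ADDR=localhost MASTER_PORT=29500 \
#     python -m bloom_inference.cli serve bigscience/bloom --sharded \
#     --shard-directory /data/shards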
@app.command()
def prepare_weights(
    model_name: str,
    shard_directory: Path,
    cache_directory: Path,
    num_shard: int = 1,
):
    # Delegates to the prepare_weights module; it is imported under an alias
    # above because this command function shadows the name `prepare_weights`.
    prepare_weights_module.prepare_weights(
        model_name, cache_directory, shard_directory, num_shard
    )
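

# Example invocation, as a sketch only: the model name, paths, and shard count
# are placeholders, and it assumes Typer turns the function name into the
# `prepare-weights` command and `num_shard` into the `--num-shard` option.
#
#   python -m bloom_inference.cli prepare-weights bigscience/bloom \
#     /data/shards /data/cache --num-shard 8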
if __name__ == "__main__":
app()