2022-10-18 07:19:03 -06:00
|
|
|
import os
|
2022-10-17 06:59:00 -06:00
|
|
|
import typer
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Optional
|
|
|
|
|
2022-10-18 07:19:03 -06:00
|
|
|
from bloom_inference import prepare_weights, server
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
# Typer application object; the functions decorated with ``@app.command()``
# below are registered on it, and it is invoked from the ``__main__`` guard.
app = typer.Typer()
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
    uds_path: Path = Path("/tmp/bloom-inference"),
):
    """Check the sharded-mode prerequisites, then call ``server.serve``.

    When ``sharded`` is true, ``shard_directory`` must be provided and the
    environment variables RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT must
    all be set; otherwise no validation is performed.

    Args:
        model_name: Forwarded unchanged to ``server.serve``.
        sharded: Enables the sharded-mode checks and is forwarded to
            ``server.serve``.
        shard_directory: Required when ``sharded`` is true; forwarded to
            ``server.serve``.
        uds_path: Forwarded to ``server.serve`` (presumably a Unix domain
            socket location, judging by the name -- confirm in ``server``).

    Raises:
        ValueError: ``sharded`` is true but ``shard_directory`` is None.
        RuntimeError: ``sharded`` is true but one of the required
            environment variables is not set.
    """
    if sharded:
        # Explicit raises instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently skip these checks.
        if shard_directory is None:
            raise ValueError("shard_directory must be set when sharded is True")

        # NOTE(review): these look like the usual distributed-launch
        # variables; only their presence is checked here, not their values.
        for env_name in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(env_name) is None:
                raise RuntimeError(f"{env_name} must be set when sharded is True")

    server.serve(model_name, sharded, uds_path, shard_directory)
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def prepare_weights(
    model_name: str,
    shard_directory: Path,
    cache_directory: Path,
    num_shard: int = 1,
):
    """Delegate to ``bloom_inference.prepare_weights.prepare_weights``.

    Args:
        model_name: Forwarded as the first positional argument.
        shard_directory: Forwarded as the third positional argument.
        cache_directory: Forwarded as the second positional argument.
        num_shard: Forwarded as the fourth positional argument; defaults to 1.
    """
    # Defining this command rebinds the module-level name ``prepare_weights``
    # to the function itself, hiding the ``bloom_inference.prepare_weights``
    # module imported at the top of the file. Looking up
    # ``prepare_weights.prepare_weights`` would then fail with AttributeError,
    # so the module is imported here under a distinct local alias.
    from bloom_inference import prepare_weights as weights_module

    # NOTE(review): cache_directory is passed before shard_directory, which
    # differs from this command's parameter order -- presumably matching the
    # callee's signature; confirm against bloom_inference.prepare_weights.
    weights_module.prepare_weights(
        model_name, cache_directory, shard_directory, num_shard
    )
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the Typer application only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    app()
|