hf_text-generation-inference/server/text_generation/cli.py

69 lines
1.7 KiB
Python
Raw Normal View History

2022-10-18 07:19:03 -06:00
import os
import sys
2022-10-17 06:59:00 -06:00
import typer
from pathlib import Path
from loguru import logger
2023-01-31 10:53:56 -07:00
from typing import Optional
2022-10-17 06:59:00 -06:00
from text_generation import server, utils
2023-02-13 05:02:45 -07:00
from text_generation.tracing import setup_tracing
2022-10-17 06:59:00 -06:00
# Typer application object; each function decorated with `@app.command()`
# below is registered on it as a CLI subcommand.
app: typer.Typer = typer.Typer()
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Run the text-generation server for `model_id`."""
    # When sharded, the RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
    # environment variables must all be present. This is validated
    # explicitly instead of with `assert`, because assert statements are
    # stripped when Python runs with -O and the check would silently vanish.
    if sharded:
        for env_var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(env_var) is None:
                raise RuntimeError(f"{env_var} must be set when sharded is True")

    # loguru's built-in level names are upper-case and matched exactly;
    # normalise so e.g. `--logger-level info` works too.
    logger_level = logger_level.upper()

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        # Only records emitted from the `text_generation` package (and its
        # submodules) are kept.
        filter="text_generation",
        level=logger_level,
        # Emit each record as JSON when requested.
        serialize=json_output,
        backtrace=True,
        # Keep local variable values out of logged tracebacks.
        diagnose=False,
    )

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        # os.getenv returns a str when RANK is set but the int default
        # otherwise; coerce so `shard` always has a consistent int type.
        setup_tracing(shard=int(os.getenv("RANK", 0)), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)
2022-10-17 06:59:00 -06:00
@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
    """Fetch the weight files for `model_id` (delegates to `utils.download_weights`)."""
    # `extension` is forwarded unchanged; presumably it selects which weight
    # file type is fetched (default ".safetensors") -- confirm in utils.
    utils.download_weights(
        model_id,
        revision,
        extension,
    )
2022-10-17 06:59:00 -06:00
if __name__ == "__main__":
    # Executed as a script rather than imported: hand control to the Typer
    # app, which parses argv and dispatches to the matching subcommand.
    app()