# hf_text-generation-inference/server/text_generation/cli.py

import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation import server, utils
from text_generation.tracing import setup_tracing
# Typer application object; the functions decorated with `@app.command()`
# in this module (`serve`, `download_weights`) are registered as its commands.
app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Configure logging/tracing and start the model server.

    Args:
        model_id: Model identifier, forwarded to ``server.serve``.
        revision: Optional model revision, forwarded to ``server.serve``.
        sharded: Forwarded to ``server.serve``. When True, the environment
            variables RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT must be set.
        quantize: Forwarded to ``server.serve``.
        uds_path: Forwarded to ``server.serve``.
        logger_level: Minimum loguru level written to stdout.
        json_output: If True, log records are serialized as JSON.
        otlp_endpoint: If given, OpenTelemetry tracing is set up to export
            to this endpoint.

    Raises:
        RuntimeError: If ``sharded`` is True and any of the required
            environment variables is missing.
    """
    if sharded:
        # Explicit check rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        missing = [
            name
            for name in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
            if os.getenv(name) is None
        ]
        if missing:
            raise RuntimeError(
                f"{', '.join(missing)} must be set when sharded is True"
            )

    # Replace loguru's default handler with a stdout handler filtered to the
    # `text_generation` namespace.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        # os.getenv returns a str when RANK is set; convert so the shard id
        # has the same type (int) whether or not RANK is present.
        setup_tracing(
            shard=int(os.getenv("RANK", "0")), otlp_endpoint=otlp_endpoint
        )

    server.serve(model_id, revision, sharded, quantize, uds_path)


def _safetensors_path(pt_file: Path) -> Path:
    """Return the ``.safetensors`` path that a PyTorch weight file converts to.

    The file stays in the same directory. An exact leading ``pytorch_`` prefix
    is removed from the stem, so ``pytorch_model-00001-of-00002.bin`` maps to
    ``model-00001-of-00002.safetensors``. Other stems are kept unchanged.
    """
    prefix = "pytorch_"
    stem = pt_file.stem
    # Strip the exact prefix only. `stem.lstrip("pytorch_")` would strip any
    # leading run of the characters {p, y, t, o, r, c, h, _}, mangling names
    # such as "checkpoint" -> "eckpoint" or "tf_model" -> "f_model".
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return pt_file.parent / f"{stem}.safetensors"


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Download the weight files of ``model_id`` unless they are already cached.

    If the weights are missing from the local cache, files with ``extension``
    are downloaded. When no ``.safetensors`` weights exist and the default
    extension was requested, the ``.bin`` PyTorch weights are downloaded
    instead and converted to ``.safetensors`` files.

    Args:
        model_id: Model identifier passed to the ``utils`` helpers.
        revision: Optional model revision passed to the ``utils`` helpers.
        extension: File extension of the weight files to fetch.
        logger_level: Minimum loguru level written to stdout.
        json_output: If True, log records are serialized as JSON.

    Raises:
        utils.EntryNotFoundError: Re-raised when no files with a
            non-default ``extension`` are found. Presumably it is also raised by
            ``utils.weight_hub_files`` when the ``.bin`` fallback is missing;
            TODO confirm.
    """
    # Replace loguru's default handler with a stdout handler filtered to the
    # `text_generation` namespace.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Test if the files were already downloaded.
    try:
        utils.weight_files(model_id, revision, extension)
    except utils.LocalEntryNotFoundError:
        pass  # Not in the local cache: fall through to the download below.
    else:
        logger.info("Files are already present in the local cache. Skipping download.")
        return

    # Download weights directly
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
        utils.download_weights(filenames, model_id, revision)
    except utils.EntryNotFoundError:
        # Only the default safetensors request has a conversion fallback.
        if extension != ".safetensors":
            raise

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            "Converting PyTorch weights instead."
        )

        # Try to see if there are pytorch weights, then download them.
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

        # Convert pytorch weights to safetensors next to the downloaded files.
        local_st_files = [_safetensors_path(p) for p in local_pt_files]
        utils.convert_files(local_pt_files, local_st_files)


if __name__ == "__main__":
    # Run the Typer CLI when this module is executed as a script.
    app()