60 lines
1.4 KiB
Python
60 lines
1.4 KiB
Python
import os
|
|
import sys
|
|
import typer
|
|
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
|
|
from text_generation import server, utils
|
|
|
|
# Typer application object; the functions decorated with ``@app.command()``
# in this module are registered on it as CLI sub-commands.
app = typer.Typer()
|
|
|
|
|
|
@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = Path("/tmp/text-generation"),
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Configure logging, validate the sharded environment, and start serving.

    When sharded is enabled, the RANK, WORLD_SIZE, MASTER_ADDR and
    MASTER_PORT environment variables must all be set; otherwise an
    AssertionError naming the first missing variable is raised.

    model_name, sharded, quantize and uds_path are forwarded unchanged to
    text_generation.server.serve. logger_level is the minimum level passed
    to the log sink, and json_output is passed as the sink's serialize flag.
    """
    # Remove default handler so the sink configured below is the only one.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        # String filter: presumably limits output to records emitted from the
        # "text_generation" package (see loguru's `filter` docs) - confirm.
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if sharded:
        # Raise explicitly rather than using `assert`: assert statements are
        # stripped under `python -O`, which would silently skip these
        # checks. AssertionError and the message text are kept so the failure
        # looks the same as before. Variables are checked in the original
        # order and the first missing one is reported.
        for variable in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(variable) is None:
                raise AssertionError(f"{variable} must be set when sharded is True")

    server.serve(model_name, sharded, quantize, uds_path)
|
|
|
|
|
|
@app.command()
def download_weights(model_name: str, extension: str = ".safetensors") -> None:
    """Download weight files for a model.

    Thin CLI wrapper: forwards model_name and extension, in that order, to
    text_generation.utils.download_weights and discards its return value.
    """
    utils.download_weights(model_name, extension)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the Typer CLI only when this file is executed as a script,
    # not when it is imported as a module.
    app()
|