2022-10-18 07:19:03 -06:00
|
|
|
import os
|
2023-01-05 04:01:23 -07:00
|
|
|
import sys
|
2022-10-17 06:59:00 -06:00
|
|
|
import typer
|
|
|
|
|
|
|
|
from pathlib import Path
|
2023-01-05 04:01:23 -07:00
|
|
|
from loguru import logger
|
2023-01-31 10:53:56 -07:00
|
|
|
from typing import Optional
|
2022-10-17 06:59:00 -06:00
|
|
|
|
2022-10-28 11:24:00 -06:00
|
|
|
from text_generation import server, utils
|
2023-02-13 05:02:45 -07:00
|
|
|
from text_generation.tracing import setup_tracing
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
# Typer application object; the functions decorated with `@app.command()`
# in this file are registered on it, and it is invoked in the `__main__` guard.
app = typer.Typer()
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = Path("/tmp/text-generation"),
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Configure logging and optional tracing, then start the model server.

    When sharded is enabled, the environment variables RANK, WORLD_SIZE,
    MASTER_ADDR and MASTER_PORT must all be set.
    """
    if sharded:
        # Explicit check instead of `assert`, so the validation still runs
        # when Python is started with -O. The exception type (AssertionError)
        # and the messages are kept identical to the previous assert-based
        # checks, and the variables are checked in the same order.
        for required in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(required) is None:
                raise AssertionError(f"{required} must be set when sharded is True")

    # Remove default handler
    logger.remove()
    # Log to stdout, message text only, restricted to records matching the
    # "text_generation" filter; `json_output` switches on serialized records.
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        # NOTE(review): `shard` is the string from the environment when RANK
        # is set and the int 0 otherwise -- confirm setup_tracing accepts both.
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
    # CLI command: passes its three arguments, unchanged and in order,
    # to utils.download_weights.
    utils.download_weights(model_id, revision, extension)
|
2022-10-17 06:59:00 -06:00
|
|
|
|
|
|
|
|
|
|
|
# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()
|