diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py index 6d0ad24..1797400 100755 --- a/other/vllm/vllm_api_server.py +++ b/other/vllm/vllm_api_server.py @@ -1,12 +1,13 @@ import argparse import json import subprocess +import sys import time from pathlib import Path from typing import AsyncGenerator import uvicorn -from fastapi import BackgroundTasks, FastAPI, Request +from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -15,8 +16,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import random_uuid TIMEOUT_KEEP_ALIVE = 5 # seconds. -TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. app = FastAPI() +engine = None +tokenizer = None served_model = None model_config = {} @@ -36,8 +38,14 @@ def get_gpu_pstate(): return None +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + @app.get("/info") -async def generate(request: Request) -> Response: +async def info(request: Request) -> Response: return JSONResponse({ 'uptime': int(time.time() - startup_time), 'startup_time': startup_time, @@ -50,12 +58,12 @@ async def generate(request: Request) -> Response: @app.get("/model") -async def generate(request: Request) -> Response: +async def model(request: Request) -> Response: return JSONResponse({'model': served_model, 'timestamp': int(time.time())}) @app.post("/tokenize") -async def generate(request: Request) -> Response: +async def tokenize(request: Request) -> Response: request_dict = await request.json() to_tokenize = request_dict.get("input") if not to_tokenize: @@ -82,6 +90,7 @@ async def generate(request: Request) -> Response: stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case @@ -94,14 +103,8 @@ async def generate(request: Request) -> Response: ret = {"text": text_outputs} yield (json.dumps(ret) + "\0").encode("utf-8") - async def abort_request() -> None: - await engine.abort(request_id) - if stream: - background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. - background_tasks.add_task(abort_request) - return StreamingResponse(stream_results(), background=background_tasks) + return StreamingResponse(stream_results()) # Non-streaming case final_output = None @@ -121,19 +124,20 @@ async def generate(request: Request) -> Response: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) + # Do some init setup before loading the model. model_path = Path(args.model) served_model = model_path.name - try: model_config = json.loads((model_path / 'config.json').read_text()) except Exception as e: print(f"Failed to load the model's config - {e.__class__.__name__}: {e}") + sys.exit(1) engine = AsyncLLMEngine.from_engine_args(engine_args) tokenizer = get_tokenizer(engine_args.tokenizer,