Update vLLM API server to upstream

Cyberes 2024-01-10 15:46:04 -07:00
parent 6785079fea
commit e21be17d9b
1 changed file with 18 additions and 14 deletions


@@ -1,12 +1,13 @@
import argparse
import json
import subprocess
import sys
import time
from pathlib import Path
from typing import AsyncGenerator
import uvicorn
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -15,8 +16,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
engine = None
tokenizer = None
served_model = None
model_config = {}
@@ -36,8 +38,14 @@ def get_gpu_pstate():
return None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/info")
-async def generate(request: Request) -> Response:
+async def info(request: Request) -> Response:
return JSONResponse({
'uptime': int(time.time() - startup_time),
'startup_time': startup_time,
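The hunk above only shows the tail of get_gpu_pstate(). Given the subprocess import near the top of the file, a plausible reading is that it shells out to nvidia-smi for the GPU performance state; the following is a hypothetical sketch under that assumption (the gpu_id parameter and the parsing of the P-prefixed value are guesses, not the file's actual code).

import subprocess


def get_gpu_pstate(gpu_id: int = 0):
    # Hypothetical: query the GPU performance state (P0-P12) via nvidia-smi.
    try:
        output = subprocess.check_output(
            ['nvidia-smi', f'--id={gpu_id}', '--query-gpu=pstate', '--format=csv,noheader'],
            timeout=10,
        ).decode().strip()
        # nvidia-smi prints values like "P0"; strip the prefix and return an int.
        return int(output.lstrip('Pp'))
    except Exception:
        # Matches the visible fallback in the diff: return None on any failure.
        return None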
@@ -50,12 +58,12 @@ async def generate(request: Request) -> Response:
@app.get("/model")
-async def generate(request: Request) -> Response:
+async def model(request: Request) -> Response:
return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
@app.post("/tokenize")
-async def generate(request: Request) -> Response:
+async def tokenize(request: Request) -> Response:
request_dict = await request.json()
to_tokenize = request_dict.get("input")
if not to_tokenize:
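The diff truncates the /tokenize body right after the input check. A minimal sketch of how the handler could continue, assuming the module-level tokenizer created by get_tokenizer() in __main__; the error message and response keys are illustrative, not taken from the file.

@app.post("/tokenize")
async def tokenize(request: Request) -> Response:
    request_dict = await request.json()
    to_tokenize = request_dict.get("input")
    if not to_tokenize:
        # Assumed error shape; the real handler may respond differently.
        return JSONResponse({'error': 'missing "input" field'}, status_code=400)
    # get_tokenizer() returns a Hugging Face tokenizer, so calling it on the
    # text yields the token IDs directly.
    token_ids = tokenizer(to_tokenize).input_ids
    return JSONResponse({'tokens': token_ids, 'length': len(token_ids)})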
@@ -82,6 +90,7 @@ async def generate(request: Request) -> Response:
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
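For reference, a hedged client-side example of calling the /generate endpoint wired above, assuming prompt is popped from the JSON body the same way stream is and the rest is forwarded to SamplingParams (max_tokens and temperature are standard vLLM sampling arguments; the URL is a placeholder).

import requests

# Placeholder host/port; adjust to wherever the server is listening.
response = requests.post('http://localhost:8000/generate',
                         json={'prompt': 'Once upon a time',
                               'max_tokens': 64,
                               'temperature': 0.8,
                               'stream': False},
                         timeout=60)
print(response.json()['text'])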
@@ -94,14 +103,8 @@ async def generate(request: Request) -> Response:
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
-async def abort_request() -> None:
-await engine.abort(request_id)
if stream:
-background_tasks = BackgroundTasks()
-# Abort the request if the client disconnects.
-background_tasks.add_task(abort_request)
-return StreamingResponse(stream_results(), background=background_tasks)
+return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
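With the BackgroundTasks-based abort gone, upstream vLLM instead aborts the engine request from the non-streaming path when the client disconnects. A sketch of that pattern, factored into a standalone coroutine for illustration (the function name and the 499 status code mirror the upstream style but are assumptions here):

async def finish_non_streaming(request: Request, results_generator, request_id: str) -> Response:
    # Drain the generator; abort the engine request if the client goes away.
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Client disconnected: cancel the in-flight generation.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    return JSONResponse({"text": text_outputs})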
@@ -121,19 +124,20 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
# Do some init setup before loading the model.
model_path = Path(args.model)
served_model = model_path.name
try:
model_config = json.loads((model_path / 'config.json').read_text())
except Exception as e:
print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
sys.exit(1)
engine = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = get_tokenizer(engine_args.tokenizer,
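The excerpt cuts off mid-call to get_tokenizer(). Presumably the remainder finishes the tokenizer kwargs and starts uvicorn with the constants defined at the top of the file; the following is a sketch under that assumption (setting startup_time here is also a guess, since the /info handler references it but its assignment is not shown, and log_level simply mirrors upstream's call).

tokenizer = get_tokenizer(engine_args.tokenizer,
                          tokenizer_mode=engine_args.tokenizer_mode,
                          trust_remote_code=engine_args.trust_remote_code)

# Assumed: record server start time for the /info uptime calculation.
startup_time = time.time()

uvicorn.run(app,
            host=args.host,
            port=args.port,
            log_level="debug",
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE)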