update VLLM api server to upstream

Cyberes 2024-01-10 15:46:04 -07:00
parent 6785079fea
commit e21be17d9b
1 changed file with 18 additions and 14 deletions


@@ -1,12 +1,13 @@
 import argparse
 import json
 import subprocess
+import sys
 import time
 from pathlib import Path
 from typing import AsyncGenerator

 import uvicorn
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -15,8 +16,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.

 app = FastAPI()
+engine = None
+tokenizer = None
 served_model = None
 model_config = {}
@@ -36,8 +38,14 @@ def get_gpu_pstate():
     return None


+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
 @app.get("/info")
-async def generate(request: Request) -> Response:
+async def info(request: Request) -> Response:
     return JSONResponse({
         'uptime': int(time.time() - startup_time),
         'startup_time': startup_time,
@@ -50,12 +58,12 @@ async def generate(request: Request) -> Response:
 @app.get("/model")
-async def generate(request: Request) -> Response:
+async def model(request: Request) -> Response:
     return JSONResponse({'model': served_model, 'timestamp': int(time.time())})


 @app.post("/tokenize")
-async def generate(request: Request) -> Response:
+async def tokenize(request: Request) -> Response:
     request_dict = await request.json()
     to_tokenize = request_dict.get("input")
     if not to_tokenize:
@@ -82,6 +90,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
+
     results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
@@ -94,14 +103,8 @@ async def generate(request: Request) -> Response:
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
-        return StreamingResponse(stream_results(), background=background_tasks)
+        return StreamingResponse(stream_results())

     # Non-streaming case
     final_output = None
@@ -121,19 +124,20 @@ async def generate(request: Request) -> Response:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--host", type=str, default=None)
     parser.add_argument("--port", type=int, default=8000)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     engine_args = AsyncEngineArgs.from_cli_args(args)

+    # Do some init setup before loading the model.
     model_path = Path(args.model)
     served_model = model_path.name
     try:
         model_config = json.loads((model_path / 'config.json').read_text())
     except Exception as e:
         print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
+        sys.exit(1)

     engine = AsyncLLMEngine.from_engine_args(engine_args)
     tokenizer = get_tokenizer(engine_args.tokenizer,
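
For reference, a minimal client sketch against the updated endpoints. The host, port, prompt, and "max_tokens" value are assumptions for illustration; the "prompt" request field and the cumulative streaming output follow vLLM's reference api_server, not code shown in these hunks:

import json

import requests

API = "http://localhost:8000"  # assumed: default --port, server on localhost

# The /health endpoint added in this commit returns an empty 200.
assert requests.get(f"{API}/health").status_code == 200

# Stream from /generate: the server emits JSON objects separated by "\0".
payload = {"prompt": "Once upon a time", "stream": True, "max_tokens": 32}
with requests.post(f"{API}/generate", json=payload, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            # Each chunk holds the full text generated so far,
            # one entry per returned sequence.
            print(json.loads(chunk)["text"])

Note that with the abort_request background task removed, the streaming path now matches upstream: the response ends when generation finishes rather than being explicitly aborted on client disconnect.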