Update 'other/vllm/vllm_api_server.py'

Cyberes 2023-09-29 22:36:03 -06:00
parent c888f5c789
commit 2e998344d6
1 changed file with 37 additions and 4 deletions


@@ -1,5 +1,6 @@
 import argparse
 import json
+import subprocess
 import time
 from pathlib import Path
 from typing import AsyncGenerator
@@ -7,7 +8,7 @@ from typing import AsyncGenerator
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -18,10 +19,36 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
 app = FastAPI()
 served_model = None
+model_config = {}
+startup_time = int(time.time())
 
 # TODO: figure out ROPE scaling
 # TODO: make sure error messages are returned in the response
 
+
+def get_gpu_pstate():
+    cmd = ['nvidia-smi', '--query-gpu=pstate', '--format=csv']
+    output = subprocess.check_output(cmd).decode('utf-8')
+    lines = output.strip().split('\n')
+    if len(lines) > 1:
+        return int(lines[1].strip('P'))
+    else:
+        return None
+
+
+@app.get("/info")
+async def generate(request: Request) -> Response:
+    return JSONResponse({
+        'uptime': int(time.time() - startup_time),
+        'startup_time': startup_time,
+        'model': served_model,
+        'model_config': model_config,
+        'nvidia': {
+            'pstate': get_gpu_pstate()
+        }
+    })
+
+
 @app.get("/model")
 async def generate(request: Request) -> Response:
     return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
 
@@ -98,11 +125,17 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8000)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
-
-    served_model = Path(args.model).name
+    model_path = Path(args.model)
+    served_model = model_path.name
+
+    try:
+        model_config = json.loads((model_path / 'config.json').read_text())
+    except Exception as e:
+        print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
+
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     tokenizer = get_tokenizer(engine_args.tokenizer,
                               tokenizer_mode=args.tokenizer_mode,
                               trust_remote_code=args.trust_remote_code)
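With these changes applied, the new /info route can be checked with a plain HTTP GET once the server is running. A sketch using requests, assuming the server listens on localhost at the script's default port 8000 (only the port default is visible in this file; the host is an assumption):

    # Query the new /info route. localhost:8000 is an assumption based on
    # this script's --port default; adjust for the actual deployment.
    import requests

    info = requests.get('http://localhost:8000/info').json()
    print(info['model'])             # model directory name
    print(info['uptime'])            # seconds since the server started
    print(info['model_config'])      # contents of config.json, or {} if the load failed
    print(info['nvidia']['pstate'])  # GPU performance state number, or None if no GPU rows were printed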