Update 'other/vllm/vllm_api_server.py'

Cyberes 2023-09-29 22:36:03 -06:00
parent c888f5c789
commit 2e998344d6
1 changed file with 37 additions and 4 deletions

@@ -1,5 +1,6 @@
 import argparse
 import json
+import subprocess
 import time
 from pathlib import Path
 from typing import AsyncGenerator
@@ -7,7 +8,7 @@ from typing import AsyncGenerator
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -18,10 +19,36 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
 app = FastAPI()
 served_model = None
+model_config = {}
+startup_time = int(time.time())
 
 
+# TODO: figure out ROPE scaling
+# TODO: make sure error messages are returned in the response
+def get_gpu_pstate():
+    cmd = ['nvidia-smi', '--query-gpu=pstate', '--format=csv']
+    output = subprocess.check_output(cmd).decode('utf-8')
+    lines = output.strip().split('\n')
+    if len(lines) > 1:
+        return int(lines[1].strip('P'))
+    else:
+        return None
 
 
+@app.get("/info")
+async def info(request: Request) -> Response:
+    return JSONResponse({
+        'uptime': int(time.time() - startup_time),
+        'startup_time': startup_time,
+        'model': served_model,
+        'model_config': model_config,
+        'nvidia': {
+            'pstate': get_gpu_pstate()
+        }
+    })
 
 
+@app.get("/model")
+async def model_info(request: Request) -> Response:
+    return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
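
For context, get_gpu_pstate() above parses the CSV output of nvidia-smi, which typically looks like this on a single-GPU host (sample shown for illustration; the reported pstate varies with load):

$ nvidia-smi --query-gpu=pstate --format=csv
pstate
P0

With that output the function returns 0. Only the first GPU is reported on multi-GPU machines, and subprocess.check_output raises if nvidia-smi is not on PATH, so /info will error on hosts without the NVIDIA driver installed.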
@@ -98,11 +125,17 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8000)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
-    served_model = Path(args.model).name
+    model_path = Path(args.model)
+    served_model = model_path.name
+
+    try:
+        model_config = json.loads((model_path / 'config.json').read_text())
+    except Exception as e:
+        print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
+
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     tokenizer = get_tokenizer(engine_args.tokenizer,
                               tokenizer_mode=args.tokenizer_mode,
                               trust_remote_code=args.trust_remote_code)
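
For reference, a minimal client sketch for the two new endpoints (a hedged example, assuming the server is reachable on the default --port 8000 from the argument parser and that the requests library is available; neither is part of this commit):

import requests

# /info reports uptime, the served model name, the contents of its
# config.json (if it loaded), and the current GPU performance state.
info = requests.get('http://localhost:8000/info').json()
print(info['model'], info['uptime'], info['nvidia']['pstate'])

# /model is a lighter-weight endpoint returning just the model name
# and the server's current timestamp.
print(requests.get('http://localhost:8000/model').json())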