Update 'other/vllm/vllm_api_server.py'
commit 2e998344d6
parent c888f5c789
@@ -1,5 +1,6 @@
 import argparse
 import json
+import subprocess
 import time
 from pathlib import Path
 from typing import AsyncGenerator
@@ -7,7 +8,7 @@ from typing import AsyncGenerator
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -18,10 +19,36 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
 app = FastAPI()
 
 served_model = None
+model_config = {}
+startup_time = int(time.time())
 
 
+# TODO: figure out ROPE scaling
+# TODO: make sure error messages are returned in the response
+
+def get_gpu_pstate():
+    cmd = ['nvidia-smi', '--query-gpu=pstate', '--format=csv']
+    output = subprocess.check_output(cmd).decode('utf-8')
+    lines = output.strip().split('\n')
+    if len(lines) > 1:
+        return int(lines[1].strip('P'))
+    else:
+        return None
+
+
+@app.get("/info")
+async def info(request: Request) -> Response:
+    return JSONResponse({
+        'uptime': int(time.time() - startup_time),
+        'startup_time': startup_time,
+        'model': served_model,
+        'model_config': model_config,
+        'nvidia': {
+            'pstate': get_gpu_pstate()
+        }
+    })
 
 
 @app.get("/model")
 async def generate(request: Request) -> Response:
     return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
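The new `/info` endpoint aggregates uptime, the served model name, the parsed `config.json`, and the GPU performance state reported by `nvidia-smi` (first row after the CSV header, e.g. `P0`; `None` if no GPU row comes back) into a single JSON payload. As a quick illustration only (not part of the commit), a client could poll it like this, assuming the server is up on the default `localhost:8000`:

```python
# Hypothetical client for the new /info endpoint; the host and port
# are assumptions (the server defaults to port 8000 via --port).
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/info") as resp:
    info = json.loads(resp.read())

print(f"model: {info['model']}, up {info['uptime']}s")
# 'pstate' is None when nvidia-smi returned no GPU rows.
print(f"GPU pstate: {info['nvidia']['pstate']}")
```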
@@ -98,11 +125,17 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8000)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args)
-
-    served_model = Path(args.model).name
+    model_path = Path(args.model)
+    served_model = model_path.name
+
+    try:
+        model_config = json.loads((model_path / 'config.json').read_text())
+    except Exception as e:
+        print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
+
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     tokenizer = get_tokenizer(engine_args.tokenizer,
                               tokenizer_mode=args.tokenizer_mode,
                               trust_remote_code=args.trust_remote_code)
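Startup now derives `served_model` from the model path, reads the Hugging Face-style `config.json`, and only then constructs the engine; a failed read just logs the error and leaves `model_config` as the empty dict initialized at module level, so `/info` still responds. A minimal sketch of that fallback pattern as a standalone helper (hypothetical, not part of the commit):

```python
import json
from pathlib import Path

def load_model_config(model_path: Path) -> dict:
    # Mirrors the commit's behavior: return {} instead of raising when
    # config.json is missing or malformed (hypothetical helper).
    try:
        return json.loads((model_path / 'config.json').read_text())
    except Exception as e:
        print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")
        return {}
```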