import argparse
import json
import subprocess
import time
from pathlib import Path
from typing import AsyncGenerator

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
app = FastAPI()
served_model = None
model_config = {}
startup_time = int(time.time())

# TODO: figure out ROPE scaling
# TODO: make sure error messages are returned in the response


def get_gpu_pstate():
    # Returns the performance state of the first GPU as an int (e.g. 0 for
    # "P0"), or None if nvidia-smi is unavailable or its output is unexpected.
    cmd = ['nvidia-smi', '--query-gpu=pstate', '--format=csv']
    try:
        output = subprocess.check_output(cmd).decode('utf-8')
    except (OSError, subprocess.CalledProcessError):
        return None
    lines = output.strip().split('\n')
    if len(lines) > 1:
        # The first line is the CSV header; the second holds the first GPU's value.
        return int(lines[1].strip().lstrip('P'))
    return None


@app.get("/info")
async def info(request: Request) -> Response:
    return JSONResponse({
        'uptime': int(time.time() - startup_time),
        'startup_time': startup_time,
        'model': served_model,
        'model_config': model_config,
        'nvidia': {
            'pstate': get_gpu_pstate()
        }
    })


@app.get("/model")
async def model(request: Request) -> Response:
    return JSONResponse({'model': served_model, 'timestamp': int(time.time())})


@app.post("/tokenize")
async def tokenize(request: Request) -> Response:
    request_dict = await request.json()
    to_tokenize = request_dict.get("input")
    if not to_tokenize:
        return JSONResponse({'error': 'must have input field'}, status_code=400)
    tokens = tokenizer.tokenize(to_tokenize)
    response = {}
    if request_dict.get("return", False):
        response['tokens'] = tokens
    response['length'] = len(tokens)
    return JSONResponse(response)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt", None)
    if prompt is None:
        return JSONResponse({'error': 'must have prompt field'}, status_code=400)
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    async def abort_request() -> None:
        await engine.abort(request_id)

    if stream:
        background_tasks = BackgroundTasks()
        # Abort the request once the response terminates
        # (e.g. the client disconnects mid-stream).
        background_tasks.add_task(abort_request)
        return StreamingResponse(stream_results(), background=background_tasks)

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
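# Illustrative request (a sketch, not part of the server): with the default
# --host/--port flags parsed below, a non-streaming completion looks like:
#
#   curl -X POST http://localhost:8000/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Hello,", "stream": false, "max_tokens": 16}'
#
# Every field other than "prompt" and "stream" is forwarded to SamplingParams,
# so only names SamplingParams accepts (e.g. max_tokens, temperature) are valid.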
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    model_path = Path(args.model)
    served_model = model_path.name
    try:
        model_config = json.loads((model_path / 'config.json').read_text())
    except Exception as e:
        print(f"Failed to load the model's config - {e.__class__.__name__}: {e}")

    engine = AsyncLLMEngine.from_engine_args(engine_args)
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=args.tokenizer_mode,
                              trust_remote_code=args.trust_remote_code)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
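# Minimal streaming-client sketch (illustrative; assumes the third-party
# `requests` package). Chunks are the "\0"-delimited JSON blobs emitted by
# stream_results() above:
#
#   import json, requests
#   resp = requests.post("http://localhost:8000/generate",
#                        json={"prompt": "Hello,", "stream": True},
#                        stream=True)
#   for chunk in resp.iter_lines(delimiter=b"\0"):
#       if chunk:
#           print(json.loads(chunk)["text"])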