actually we don't want to emulate openai

Cyberes 2023-09-12 01:04:11 -06:00
parent 747d838138
commit 40ac84aa9a
19 changed files with 348 additions and 150 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
proxy-server.db
.idea
config/config.yml
install vllm_gptq-0.1.3-py3-none-any.whl
# ---> Python
# Byte-compiled / optimized / DLL files

View File

@ -10,7 +10,8 @@ The purpose of this server is to abstract your LLM backend from your frontend API
2. `python3 -m venv venv`
3. `source venv/bin/activate`
4. `pip install -r requirements.txt`
5. `python3 server.py`
5. `wget https://git.evulid.cc/attachments/89c87201-58b1-4e28-b8fd-d0b323c810c4 -O vllm_gptq-0.1.3-py3-none-any.whl && pip install vllm_gptq-0.1.3-py3-none-any.whl`
6. `python3 server.py`
An example systemctl service file is provided in `other/local-llm.service`.

View File

@ -18,6 +18,7 @@ def init_db():
CREATE TABLE prompts (
ip TEXT,
token TEXT DEFAULT NULL,
backend TEXT,
prompt TEXT,
prompt_tokens INTEGER,
response TEXT,
@ -71,8 +72,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path)
c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit()
conn.close()
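The `prompts` table gains a `backend` column and `log_prompt` now records `opts.mode` in it, so the INSERT grows from 12 to 13 placeholders. A self-contained sketch of the new row shape (only the `backend` column and the 13-value INSERT come from this diff; column names after `response` are assumptions for illustration):

```python
# Standalone sketch, not the project's database module. Column names after
# `response` are assumed; the `backend` column and 13-value INSERT mirror the diff.
import json
import sqlite3
import time

conn = sqlite3.connect(":memory:")
c = conn.cursor()
c.execute("""
    CREATE TABLE prompts (
        ip TEXT,
        token TEXT DEFAULT NULL,
        backend TEXT,              -- new: which backend mode served the prompt
        prompt TEXT,
        prompt_tokens INTEGER,
        response TEXT,
        response_tokens INTEGER,   -- names from here on are assumptions
        response_status INTEGER,
        generation_time FLOAT,
        model TEXT,
        parameters TEXT,
        headers TEXT,
        timestamp INTEGER
    )
""")
c.execute(
    "INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    ("127.0.0.1", None, "vllm", "Hello", 1, "Hi there.", 3, 200, 1.2,
     "some-model", json.dumps({"temperature": 0.7}), json.dumps({}), int(time.time())),
)
conn.commit()
conn.close()
```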
@ -129,15 +130,17 @@ def average_column_for_model(table_name, column_name, model_name):
return result[0]
def weighted_average_column_for_model(table_name, column_name, model_name, exclude_zeros: bool = False):
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
models = [row[0] for row in cursor.fetchall()]
cursor.execute(f"SELECT DISTINCT model, backend FROM {table_name}")
models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
model_averages = {}
for model in models:
cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
for model, backend in models_backends:
if backend != backend_name:
continue
cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model, backend))
results = cursor.fetchall()
if not results:
@ -155,11 +158,11 @@ def weighted_average_column_for_model(table_name, column_name, model_name, exclu
if total_weight == 0:
continue
model_averages[model] = weighted_sum / total_weight
model_averages[(model, backend)] = weighted_sum / total_weight
conn.close()
return model_averages.get(model_name)
return model_averages.get((model_name, backend_name))
def sum_column(table_name, column_name):
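With the averages now keyed by `(model, backend)`, callers pass the backend mode alongside the model name. A hypothetical call matching the new signature (the real call sites in the background thread further down pass `opts.running_model` and `opts.mode`):

```python
# Hypothetical call against the new signature; table/column names come from
# the diff, the model and backend values are made up.
avg_gen_time = weighted_average_column_for_model(
    "prompts",           # table_name
    "generation_time",   # column_name
    "some-model",        # model_name
    "vllm",              # backend_name (new argument, opts.mode at real call sites)
    exclude_zeros=True,
) or 0
```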

View File

@ -1,10 +1,11 @@
import requests
from llm_server import opts
from pathlib import Path
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
# TODO: only use one try/catch
if opts.mode == 'oobabooga':
try:
@ -22,11 +23,9 @@ def get_running_model():
return False, e
elif opts.mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
model_name = Path(r_json['data'][0]['root']).name
# r_json['data'][0]['root'] = model_name
return model_name, None
return r_json['model'], None
except Exception as e:
return False, e
else:
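In vLLM mode the proxy now asks the backend's own `/model` route instead of the OpenAI-style `/v1/models`. A sketch of the exchange it expects (URL is illustrative; the `{'model': ..., 'timestamp': ...}` shape matches `other/vllm_api_server.py` added below):

```python
# Illustrative only: query the custom vLLM API server's /model endpoint.
import requests

r = requests.get("http://127.0.0.1:8000/model", timeout=3)
r_json = r.json()
print(r_json["model"])   # e.g. the directory name passed to --model at startup
```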

View File

@ -1,4 +1,4 @@
from typing import Union, Tuple
from typing import Tuple, Union
class LLMBackend:
@ -10,3 +10,12 @@ class LLMBackend:
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# raise NotImplementedError
def get_parameters(self, parameters) -> Union[dict, None]:
"""
Validate and return the parameters for this backend.
Lets you set defaults for specific backends.
:param parameters:
:return:
"""
raise NotImplementedError
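A toy subclass, just to illustrate the contract (not one of the project's backends): take the raw request body, drop or remap whatever the backend doesn't accept, and return the parameter dict.

```python
# Toy backend for illustration only; the import path is taken from the other
# diffs in this commit.
from llm_server.llm.llm_backend import LLMBackend


class EchoBackend(LLMBackend):
    def get_parameters(self, parameters):
        params = dict(parameters)
        params.pop('prompt', None)   # the prompt is sent separately, not as a sampling parameter
        return params
```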

View File

@ -1,15 +1,11 @@
from typing import Tuple
import requests
from flask import jsonify
from ... import opts
from ..llm_backend import LLMBackend
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
from ..llm_backend import LLMBackend
class OobaboogaLLMBackend(LLMBackend):
@ -71,3 +67,7 @@ class OobaboogaLLMBackend(LLMBackend):
# return r_json['result'], None
# except Exception as e:
# return False, e
def get_parameters(self, parameters):
del parameters['prompt']
return parameters

View File

@ -1,17 +1,14 @@
"""
This file is used by the worker that processes requests.
"""
import io
import json
import time
from uuid import uuid4
import requests
from requests import Response
from llm_server import opts
from llm_server.database import tokenizer
from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed
@ -19,7 +16,7 @@ from llm_server.routes.cache import redis
def prepare_json(json_data: dict):
# logit_bias is not currently supported
del json_data['logit_bias']
# del json_data['logit_bias']
return json_data
@ -83,26 +80,26 @@ def transform_prompt_to_text(prompt: list):
def handle_blocking_request(json_data: dict):
try:
r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
# TODO: check for error here?
response_json = r.json()
response_json['error'] = False
# response_json = r.json()
# response_json['error'] = False
new_response = Response()
new_response.status_code = r.status_code
new_response._content = json.dumps(response_json).encode('utf-8')
new_response.raw = io.BytesIO(new_response._content)
new_response.headers = r.headers
new_response.url = r.url
new_response.reason = r.reason
new_response.cookies = r.cookies
new_response.elapsed = r.elapsed
new_response.request = r.request
# new_response = Response()
# new_response.status_code = r.status_code
# new_response._content = json.dumps(response_json).encode('utf-8')
# new_response.raw = io.BytesIO(new_response._content)
# new_response.headers = r.headers
# new_response.url = r.url
# new_response.reason = r.reason
# new_response.cookies = r.cookies
# new_response.elapsed = r.elapsed
# new_response.request = r.request
return True, new_response, None
return True, r, None
def generate(json_data: dict):
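The worker now POSTs the prepared body straight to the vLLM server's `/generate` and hands back the raw `requests.Response` instead of rebuilding an OpenAI-style one. An illustrative payload (values and URL are made up; the keys are `prompt` plus SamplingParams-style fields, matching `other/vllm_api_server.py`):

```python
# Illustrative request to the new /generate route; not project code.
import requests

payload = {
    "prompt": "Once upon a time",
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
    "max_tokens": 128,
}
r = requests.post("http://127.0.0.1:8000/generate", json=payload)
print(r.status_code)
print(r.json()["text"])   # list of prompt + completion strings
```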

View File

@ -1,13 +1,8 @@
from pathlib import Path
import requests
from llm_server import opts
def get_vlmm_models_info():
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
return r_json
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/chu-tianxiang/vllm-gptq" target="_blank">vllm-gptq</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>
<li><kbd>temperature</kbd></li>
<li><kbd>top_p</kbd></li>
<li><kbd>top_k</kbd></li>
<li><kbd>max_new_tokens</kbd></li>
</ul>"""

View File

@ -1,8 +1,9 @@
from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
@ -22,19 +23,23 @@ class VLLMBackend(LLMBackend):
response_status_code = 0
if response_valid_json:
backend_response = response_json_body
if len(response_json_body.get('text', [])):
# Does vllm return the prompt and the response together???
backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
else:
# Failsafe
backend_response = ''
if response_json_body.get('error'):
backend_err = True
error_type = response_json_body.get('error_type')
error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
backend_response = format_sillytavern_err(
f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
f'HTTP CODE {response_status_code}'
)
# TODO: how to detect an error?
# if backend_response == '':
# backend_err = True
# backend_response = format_sillytavern_err(
# f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
# f'HTTP CODE {response_status_code}'
# )
log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify(backend_response), 200
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({'results': [{'text': backend_response}]}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
@ -44,13 +49,24 @@ class VLLMBackend(LLMBackend):
'results': [{'text': backend_response}]
}), 200
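Since the backend echoes the prompt back concatenated with the completion, the handler splits on the prompt to recover just the generated text. The same logic on a made-up response:

```python
# Made-up /generate response; mirrors the splitting logic above.
prompt = "Once upon a time"
response_json_body = {"text": ["Once upon a time there was a proxy server."]}

if len(response_json_body.get("text", [])):
    backend_response = response_json_body["text"][0].split(prompt)[1].strip(" ").strip("\n")
else:
    backend_response = ""   # failsafe

print(backend_response)   # -> "there was a proxy server."
```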
def validate_params(self, params_dict: dict):
try:
sampling_params = SamplingParams(**params_dict)
except ValueError as e:
print(e)
return False, e
return True, None
# def validate_params(self, params_dict: dict):
# default_params = SamplingParams()
# try:
# sampling_params = SamplingParams(
# temperature=params_dict.get('temperature', default_params.temperature),
# top_p=params_dict.get('top_p', default_params.top_p),
# top_k=params_dict.get('top_k', default_params.top_k),
# use_beam_search=True if params_dict['num_beams'] > 1 else False,
# length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
# early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
# stop=params_dict.get('stopping_strings', default_params.stop),
# ignore_eos=params_dict.get('ban_eos_token', False),
# max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
# )
# except ValueError as e:
# print(e)
# return False, e
# return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
@ -61,3 +77,33 @@ class VLLMBackend(LLMBackend):
# return r_json, None
# except Exception as e:
# return False, e
def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
default_params = SamplingParams()
try:
sampling_params = SamplingParams(
temperature=parameters.get('temperature', default_params.temperature),
top_p=parameters.get('top_p', default_params.top_p),
top_k=parameters.get('top_k', default_params.top_k),
use_beam_search=True if parameters['num_beams'] > 1 else False,
stop=parameters.get('stopping_strings', default_params.stop),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
)
except ValueError as e:
print(e)
return None, e
return vars(sampling_params), None
# def transform_sampling_params(params: SamplingParams):
# return {
# 'temperature': params['temperature'],
# 'top_p': params['top_p'],
# 'top_k': params['top_k'],
# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
# stop = parameters.get('stopping_strings', default_params.stop),
# ignore_eos = parameters.get('ban_eos_token', False),
# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
# }
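`get_parameters` leans on `SamplingParams` itself for validation: bad values raise `ValueError`, which is returned to the caller instead of a parameter dict. A small sketch of both outcomes (assumes vLLM's `SamplingParams` rejects out-of-range values, which is what the try/except above relies on):

```python
# Sketch of the validation behaviour get_parameters() relies on; assumes
# SamplingParams raises ValueError for out-of-range values.
from vllm import SamplingParams

ok = SamplingParams(temperature=0.8, top_p=0.95, top_k=50, max_tokens=200)
print(vars(ok)["max_tokens"])        # 200, same dict shape get_parameters() returns

try:
    SamplingParams(top_p=5.0)        # out of range
except ValueError as e:
    print(e)                         # this exception becomes the second return value
```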

View File

@ -38,9 +38,9 @@ class OobaRequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.request.headers.get('X-Api-Key')
self.parameters = self.get_parameters()
self.priority = self.get_priority()
self.backend = self.get_backend()
self.parameters = self.parameters_invalid_msg = None
def validate_request(self) -> (bool, Union[str, None]):
# TODO: move this to LLMBackend
@ -56,19 +56,9 @@ class OobaRequestHandler:
else:
return self.request.remote_addr
def get_parameters(self):
# TODO: make this a LLMBackend method
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
parameters = self.request_json_body.copy()
if opts.mode in ['oobabooga', 'hf-textgen']:
del parameters['prompt']
elif opts.mode == 'vllm':
parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
else:
raise Exception
return parameters
# def get_parameters(self):
# # TODO: make this a LLMBackend method
# return self.backend.get_parameters()
def get_priority(self):
if self.token:
@ -91,24 +81,26 @@ class OobaRequestHandler:
else:
raise Exception
def get_parameters(self):
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
def handle_request(self):
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
self.get_parameters()
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
# Fix bug on text-generation-inference
# https://github.com/huggingface/text-generation-inference/issues/929
if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
self.request_json_body['typical_p'] = 0.998
if opts.mode == 'vllm':
full_model_path = redis.get('full_model_path')
if not full_model_path:
raise Exception
self.request_json_body['model'] = full_model_path.decode()
request_valid, invalid_request_err_msg = self.validate_request()
params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
if not self.parameters:
params_valid = False
else:
params_valid = True
if not request_valid or not params_valid:
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
@ -119,21 +111,27 @@ class OobaRequestHandler:
'results': [{'text': err}]
}), 200
# Reconstruct the request JSON with the validated parameters and prompt.
prompt = self.request_json_body.get('prompt', '')
llm_request = {**self.parameters, 'prompt': prompt}
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
else:
# Client was rate limited
event = None
if not event:
return self.handle_ratelimited()
event.wait()
success, response, error_msg = event.data
end_time = time.time()
elapsed_time = end_time - self.start_time
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
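After validation the handler rebuilds the body it actually queues: the backend-validated parameters plus the original prompt, rather than the raw client JSON. The reconstruction with made-up values:

```python
# Made-up values; shows what ends up in the priority queue after this change.
parameters = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 128}   # from backend.get_parameters()
request_json_body = {"prompt": "Hello", "temperature": 0.7, "top_p": 0.9, "max_new_tokens": 128}

prompt = request_json_body.get("prompt", "")
llm_request = {**parameters, "prompt": prompt}
print(llm_request)
# {'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 128, 'prompt': 'Hello'}
```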

View File

@ -7,14 +7,7 @@ from ... import opts
@bp.route('/generate', methods=['POST'])
@bp.route('/chat/completions', methods=['POST'])
def generate():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

View File

@ -82,6 +82,7 @@ def generate_stats():
'online': online,
'endpoints': {
'blocking': f'https://{opts.base_client_api}',
'streaming': f'wss://{opts.base_client_api}/stream',
},
'queue': {
'processing': active_gen_workers,
@ -104,9 +105,9 @@ def generate_stats():
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}
if opts.mode in ['oobabooga', 'hf-textgen']:
output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
else:
output['endpoints']['streaming'] = None
# if opts.mode in ['oobabooga', 'hf-textgen']:
# output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
# else:
# output['endpoints']['streaming'] = None
return deep_sort(output)
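The streaming endpoint is now advertised unconditionally next to the blocking one. The resulting `endpoints` block (hostname is illustrative):

```python
# Illustrative value of output['endpoints'] after this change.
endpoints = {
    "blocking": "https://llm.example.com",
    "streaming": "wss://llm.example.com/stream",
}
```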

View File

@ -4,9 +4,7 @@ from flask import jsonify, request
from . import bp
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
from ...llm.vllm.info import get_vlmm_models_info
# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
@ -20,16 +18,7 @@ from ...llm.vllm.info import get_vlmm_models_info
@bp.route('/model', methods=['GET'])
@bp.route('/models', methods=['GET'])
def get_model():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -46,18 +35,10 @@ def get_model():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
if opts.mode in ['oobabooga', 'hf-texgen']:
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
elif opts.mode == 'vllm':
response = jsonify({
**get_vlmm_models_info(),
'timestamp': int(time.time())
}), 200
else:
raise Exception
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
cache.set(cache_key, response, timeout=60)
return response
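Every backend mode now gets the same `/model` response from the proxy. A hypothetical client call (host and route prefix are made up):

```python
# Hypothetical client; the response shape matches the unified branch above.
import requests

r = requests.get("https://llm.example.com/api/v1/model", timeout=5)
data = r.json()
print(data["result"], data["timestamp"])   # running model + unix time, cached for 60s
```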

View File

@ -3,7 +3,6 @@ from threading import Thread
import requests
import llm_server
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
@ -25,16 +24,6 @@ class MainBackgroundThread(Thread):
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
if opts.mode == 'vllm':
while True:
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
redis.set('full_model_path', r_json['data'][0]['root'])
break
except Exception as e:
print(e)
def run(self):
while True:
if opts.mode == 'oobabooga':
@ -77,13 +66,13 @@ class MainBackgroundThread(Thread):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, exclude_zeros=True) or 0
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, exclude_zeros=True) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)

70
other/vllm-gptq-setup.py Normal file
View File

@ -0,0 +1,70 @@
import io
import os
import re
from typing import List
import setuptools
from torch.utils.cpp_extension import BuildExtension
ROOT_DIR = os.path.dirname(__file__)
"""
Build vllm-gptq without any CUDA
"""
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def find_version(filepath: str):
"""Extract version information from the given filepath.
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
def read_readme() -> str:
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
return requirements
setuptools.setup(
name="vllm-gptq",
version=find_version(get_path("vllm", "__init__.py")),
author="vLLM Team",
license="Apache 2.0",
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
long_description=read_readme(),
long_description_content_type="text/markdown",
url="https://github.com/vllm-project/vllm",
project_urls={
"Homepage": "https://github.com/vllm-project/vllm",
"Documentation": "https://vllm.readthedocs.io/en/latest/",
},
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
cmdclass={"build_ext": BuildExtension},
)

94
other/vllm_api_server.py Normal file
View File

@ -0,0 +1,94 @@
import argparse
import json
import time
from pathlib import Path
from typing import AsyncGenerator
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
served_model = None
@app.get("/model")
async def generate(request: Request) -> Response:
return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
served_model = Path(args.model).name
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
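A hedged client sketch for the standalone server above, exercising both routes (host and port are the script's defaults; prompt and sampling values are made up):

```python
# Not part of the commit; a quick way to poke the server above.
import json
import requests

base = "http://localhost:8000"

# Model info route used by get_running_model() in vLLM mode.
print(requests.get(f"{base}/model").json())           # {'model': ..., 'timestamp': ...}

# Blocking generation: prompt + SamplingParams fields.
r = requests.post(f"{base}/generate", json={
    "prompt": "The quick brown fox",
    "max_tokens": 32,
    "temperature": 0.7,
})
print(r.json()["text"][0])                            # prompt + completion

# Streaming: each chunk is a JSON object terminated by a NUL byte.
with requests.post(f"{base}/generate", json={"prompt": "Hi", "stream": True}, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])
```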

View File

@ -9,4 +9,4 @@ redis
gevent
async-timeout
flask-sock
vllm
auto_gptq

View File

@ -11,7 +11,7 @@ from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
from llm_server.llm.hf_textgen.info import hf_textget_info
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
@ -20,6 +20,13 @@ from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread
try:
import vllm
except ModuleNotFoundError as e:
print('Could not import vllm-gptq:', e)
print('Please see vllm.md for install instructions')
sys.exit(1)
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@ -130,6 +137,10 @@ def home():
else:
info_html = ''
mode_info = ''
if opts.mode == 'vllm':
mode_info = vllm_info
return render_template('home.html',
llm_middleware_name=config['llm_middleware_name'],
analytics_tracking_code=analytics_tracking_code,
@ -143,7 +154,7 @@ def home():
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
extra_info=mode_info,
)
@ -156,5 +167,11 @@ def fallback(first=None, rest=None):
}), 404
@app.errorhandler(500)
def server_error(e):
print(e)
return {'error': True}, 500
if __name__ == "__main__":
app.run(host='0.0.0.0')

4
vllm.md Normal file
View File

@ -0,0 +1,4 @@
```bash
wget https://git.evulid.cc/attachments/6e7bfc04-cad4-4494-a98d-1391fbb402d3 -O vllm-0.1.3-cp311-cp311-linux_x86_64.whl && pip install vllm-0.1.3-cp311-cp311-linux_x86_64.whl
pip install auto_gptq
```