implement vllm backend

Cyberes 2023-09-11 20:47:19 -06:00
parent c14cc51f09
commit 4c9d543eab
21 changed files with 358 additions and 33 deletions

View File

@ -9,7 +9,7 @@ token_limit: 8192
# How many requests a single IP is allowed to put in the queue.
# If an IP tries to put more than this, their request will be rejected
# until the other(s) are completed.
ip_in_queue_max: 2
simultaneous_requests_per_ip: 2
## Optional

View File

@ -13,7 +13,7 @@ config_default_vars = {
'average_generation_time_mode': 'database',
'info_html': None,
'show_total_output_tokens': True,
'ip_in_queue_max': 3,
'simultaneous_requests_per_ip': 3,
'show_backend_info': True,
'max_new_tokens': 500
}
@ -22,6 +22,7 @@ config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middlewar
mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Chat Completion', 'Reverse Proxy', 'N/A'),
}

View File

@ -8,5 +8,9 @@ def generator(request_json_body):
elif opts.mode == 'hf-textgen':
from .hf_textgen.generate import generate
return generate(request_json_body)
elif opts.mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body)
return r
else:
raise Exception

View File

@ -1,7 +1,9 @@
import sys
from typing import Tuple
import requests
from flask import jsonify
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
@ -47,3 +49,11 @@ class HfTextgenLLMBackend(LLMBackend):
if params_dict.get('typical_p', 0) > 0.998:
return False, '`typical_p` must be less than 0.999'
return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
# backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
# r_json = backend_response.json()
# return r_json['model_id'].replace('/', '_'), None
# except Exception as e:
# return False, e

View File

@ -1,9 +1,11 @@
import requests
from llm_server import opts
from pathlib import Path
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
if opts.mode == 'oobabooga':
try:
backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
@ -18,5 +20,14 @@ def get_running_model():
return r_json['model_id'].replace('/', '_'), None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
model_name = Path(r_json['data'][0]['root']).name
# r_json['data'][0]['root'] = model_name
return model_name, None
except Exception as e:
return False, e
else:
raise Exception
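
For reference, a minimal sketch of the /v1/models response shape this new vllm branch assumes. The field names ('data', 'root') come from the parsing code above; the example values are made up.

```python
from pathlib import Path

# Hypothetical /v1/models payload from a vLLM OpenAI-compatible server;
# only the fields read above ('data' and 'root') matter here.
r_json = {
    'object': 'list',
    'data': [{'id': '/models/some-org/some-model', 'object': 'model', 'root': '/models/some-org/some-model'}],
}

# get_running_model() keeps only the final path component as the public model name.
model_name = Path(r_json['data'][0]['root']).name  # 'some-model'
```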

View File

@ -1,9 +1,12 @@
from typing import Union
from typing import Union, Tuple
class LLMBackend:
def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
def validate_params(self, params_dict: dict) -> (bool, Union[str, None]):
def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
raise NotImplementedError
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# raise NotImplementedError

View File

@ -1,5 +1,9 @@
from typing import Tuple
import requests
from flask import jsonify
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
@ -59,3 +63,11 @@ class OobaboogaLLMBackend(LLMBackend):
def validate_params(self, params_dict: dict):
# No validation required
return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
# backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
# r_json = backend_response.json()
# return r_json['result'], None
# except Exception as e:
# return False, e

View File

View File

@ -0,0 +1,116 @@
"""
This file is used by the worker that processes requests.
"""
import io
import json
import time
from uuid import uuid4
import requests
from requests import Response
from llm_server import opts
from llm_server.database import tokenizer
from llm_server.routes.cache import redis
# TODO: make the vLLM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
def prepare_json(json_data: dict):
# logit_bias is not currently supported
json_data.pop('logit_bias', None)  # pop() so requests that omit logit_bias don't raise KeyError
return json_data
def transform_to_text(json_request, api_response):
"""
Convert a streaming response into a non-streamed response. This probably isn't necessary.
:param json_request:
:param api_response:
:return:
"""
prompt = transform_prompt_to_text(json_request['messages'])
text = ''
finish_reason = None
for line in api_response.split('\n'):
if line.startswith('data:'):
try:
data = json.loads(line[5:].strip())
except json.decoder.JSONDecodeError:
break
print(data)
if 'choices' in data:
for choice in data['choices']:
if 'delta' in choice and 'content' in choice['delta']:
text += choice['delta']['content']
if data['choices'][0]['finish_reason']:
finish_reason = data['choices'][0]['finish_reason']
prompt_tokens = len(tokenizer.encode(prompt))
completion_tokens = len(tokenizer.encode(text))
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
"id": str(uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": opts.running_model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
},
"choices": [
{
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason,
"index": 0
}
]
}
def transform_prompt_to_text(prompt: list):
text = ''
for item in prompt:
text += item['content'] + '\n'
return text.strip('\n')
def handle_blocking_request(json_data: dict):
try:
r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
# TODO: check for error here?
response_json = r.json()
response_json['error'] = False
new_response = Response()
new_response.status_code = r.status_code
new_response._content = json.dumps(response_json).encode('utf-8')
new_response.raw = io.BytesIO(new_response._content)
new_response.headers = r.headers
new_response.url = r.url
new_response.reason = r.reason
new_response.cookies = r.cookies
new_response.elapsed = r.elapsed
new_response.request = r.request
return True, new_response, None
def generate(json_data: dict):
full_model_path = redis.get('full_model_path')
if not full_model_path:
raise Exception
json_data['model'] = full_model_path.decode()
if json_data.get('stream'):
raise Exception('streaming not implemented')
else:
return handle_blocking_request(json_data)
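
A rough usage sketch of the new generate() entry point, assuming Redis already holds full_model_path (set by the background thread later in this commit). The request body below is hypothetical; the worker normally builds it from the queued request.

```python
json_data = {
    'messages': [{'role': 'user', 'content': 'Hello!'}],
    'temperature': 0.7,
    'max_tokens': 100,
    'logit_bias': {},   # stripped by prepare_json() before the request is forwarded
    'stream': False,    # stream=True raises: streaming is not implemented yet
}

success, response, error_msg = generate(json_data)
if success:
    print(response.json()['choices'][0]['message']['content'])
else:
    print('backend error:', error_msg)
```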

View File

@ -0,0 +1,13 @@
from pathlib import Path
import requests
from llm_server import opts
def get_vlmm_models_info():
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
return r_json

View File

@ -0,0 +1,63 @@
from flask import jsonify
from vllm import SamplingParams
from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
class VLLMBackend(LLMBackend):
def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
response_valid_json, response_json_body = validate_json(response)
backend_err = False
try:
response_status_code = response.status_code
except:
response_status_code = 0
if response_valid_json:
backend_response = response_json_body
if response_json_body.get('error'):
backend_err = True
error_type = response_json_body.get('error_type')
error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
backend_response = format_sillytavern_err(
f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
f'HTTP CODE {response_status_code}'
)
log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify(backend_response), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
return jsonify({
'code': 500,
'msg': 'the backend did not return valid JSON',
'results': [{'text': backend_response}]
}), 200
def validate_params(self, params_dict: dict):
try:
sampling_params = SamplingParams(**params_dict)
except ValueError as e:
print(e)
return False, e
return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
# backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
# r_json = backend_response.json()
# model_path = Path(r_json['data'][0]['root']).name
# r_json['data'][0]['root'] = model_path
# return r_json, None
# except Exception as e:
# return False, e
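
Since validate_params() just defers to vLLM's SamplingParams constructor, here is a quick sketch of how it behaves, assuming upstream SamplingParams raises ValueError for out-of-range values (which is what the except clause above relies on).

```python
backend = VLLMBackend()

# These are ordinary SamplingParams arguments, so validation should pass.
ok, err = backend.validate_params({'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 200})
# ok is True, err is None

# Assuming SamplingParams rejects top_p outside (0, 1], this fails
# and err is the ValueError raised by the constructor.
ok, err = backend.validate_params({'top_p': 1.5})
```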

View File

@ -20,5 +20,5 @@ show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
ip_in_queue_max = 3
simultaneous_requests_per_ip = 3
show_backend_info = True

View File

@ -40,7 +40,7 @@ class PriorityQueue:
with self._cv:
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = redis.get_dict('queued_ip_count')
if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
return None # reject the request
heapq.heappush(self._queue, (-priority, self._index, item, event))
self._index += 1

View File

@ -8,6 +8,7 @@ from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
@ -17,6 +18,19 @@ from llm_server.routes.stats import SemaphoreCheckerThread
DEFAULT_PRIORITY = 9999
def delete_dict_key(d: dict, k: Union[str, list]):
if isinstance(k, str):
if k in d.keys():
del d[k]
elif isinstance(k, list):
for item in k:
if item in d.keys():
del d[item]
else:
raise ValueError
return d
class OobaRequestHandler:
def __init__(self, incoming_request):
self.request_json_body = None
@ -29,7 +43,8 @@ class OobaRequestHandler:
self.backend = self.get_backend()
def validate_request(self) -> (bool, Union[str, None]):
if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
# TODO: move this to LLMBackend
if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
return True, None
@ -42,11 +57,17 @@ class OobaRequestHandler:
return self.request.remote_addr
def get_parameters(self):
# TODO: make this a LLMBackend method
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
parameters = self.request_json_body.copy()
del parameters['prompt']
if opts.mode in ['oobabooga', 'hf-textgen']:
del parameters['prompt']
elif opts.mode == 'vllm':
parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
else:
raise Exception
return parameters
def get_priority(self):
@ -65,6 +86,8 @@ class OobaRequestHandler:
return OobaboogaLLMBackend()
elif opts.mode == 'hf-textgen':
return HfTextgenLLMBackend()
elif opts.mode == 'vllm':
return VLLMBackend()
else:
raise Exception
@ -83,6 +106,7 @@ class OobaRequestHandler:
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
return jsonify({
'code': 400,
'msg': 'parameter validation error',
@ -90,7 +114,7 @@ class OobaRequestHandler:
}), 200
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < opts.ip_in_queue_max or self.priority == 0:
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
else:
# Client was rate limited
@ -106,7 +130,7 @@ class OobaRequestHandler:
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
return jsonify({
'results': [{'text': backend_response}]
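
To make the new vLLM parameter handling concrete, a sketch of what get_parameters() ends up with for a chat-completions body in vllm mode. The key names come from the delete_dict_key call above; the body itself is made up.

```python
# Hypothetical incoming chat-completions body in vllm mode.
request_json_body = {
    'model': 'some-model',
    'messages': [{'role': 'user', 'content': 'Hi'}],
    'stream': False,
    'logit_bias': {},
    'temperature': 0.8,
    'top_p': 0.9,
}

# get_parameters() copies the body and strips the non-sampling keys,
# leaving only what the backend's validate_params() should see.
parameters = delete_dict_key(dict(request_json_body), ['messages', 'model', 'stream', 'logit_bias'])
# parameters == {'temperature': 0.8, 'top_p': 0.9}
```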

View File

@ -3,12 +3,20 @@ from flask import jsonify, request
from . import bp
from ..helpers.http import validate_json
from ..request_handler import OobaRequestHandler
from ... import opts
@bp.route('/generate', methods=['POST'])
@bp.route('/chat/completions', methods=['POST'])
def generate():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not request_json_body.get('prompt'):
if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request)
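
A hedged client-side sketch of the new route. The blueprint's mount point is not visible in this hunk, so the base URL below is a placeholder.

```python
import requests

BASE = 'https://proxy.example.com/api/v1'  # placeholder: actual mount point of this blueprint is not shown here

# In vllm mode, POST an OpenAI-style body to /chat/completions; /generate returns 404.
resp = requests.post(f'{BASE}/chat/completions', json={
    'messages': [{'role': 'user', 'content': 'Hello'}],
    'temperature': 0.7,
})
print(resp.status_code, resp.json())
```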

View File

@ -82,7 +82,6 @@ def generate_stats():
'online': online,
'endpoints': {
'blocking': f'https://{opts.base_client_api}',
'streaming': f'wss://{opts.base_client_api}/v1/stream',
},
'queue': {
'processing': active_gen_workers,
@ -96,7 +95,7 @@ def generate_stats():
'concurrent': opts.concurrent_gens,
'model': model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.ip_in_queue_max,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
},
'keys': {
'openaiKeys': '',
@ -104,4 +103,10 @@ def generate_stats():
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}
if opts.mode in ['oobabooga', 'hf-textgen']:
output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
else:
output['endpoints']['streaming'] = None
return deep_sort(output)
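
The net effect on the stats payload, sketched with made-up hostnames: the streaming endpoint is only advertised for backends that actually support it.

```python
# Illustrative only; other stats keys omitted and hostnames made up.
ooba_endpoints = {'blocking': 'https://proxy.example.com', 'streaming': 'wss://proxy.example.com/v1/stream'}
vllm_endpoints = {'blocking': 'https://proxy.example.com', 'streaming': None}
```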

View File

@ -3,10 +3,10 @@ import time
from flask import jsonify, request
from . import bp
from ..helpers.http import cache_control
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
from ..cache import cache
from ...llm.vllm.info import get_vlmm_models_info
# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
@ -20,7 +20,16 @@ from ..cache import cache
@bp.route('/model', methods=['GET'])
@bp.route('/models', methods=['GET'])
def get_model():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404
# We will manage caching ourselves since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -37,10 +46,18 @@ def get_model():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
if opts.mode in ['oobabooga', 'hf-textgen']:
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
elif opts.mode == 'vllm':
response = jsonify({
**get_vlmm_models_info(),
'timestamp': int(time.time())
}), 200
else:
raise Exception
cache.set(cache_key, response, timeout=60)
return response

View File

@ -3,8 +3,10 @@ from threading import Thread
import requests
import llm_server
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
@ -23,17 +25,34 @@ class MainBackgroundThread(Thread):
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
if opts.mode == 'vllm':
while True:
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
redis.set('full_model_path', r_json['data'][0]['root'])
break
except Exception as e:
print(e)
def run(self):
while True:
if opts.mode == 'oobabooga':
try:
r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
opts.running_model = r.json()['result']
redis.set('backend_online', 1)
except Exception as e:
# try:
# r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
# opts.running_model = r.json()['result']
# redis.set('backend_online', 1)
# except Exception as e:
# redis.set('backend_online', 0)
# # TODO: handle error
# print(e)
model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
# TODO: handle error
else:
opts.running_model = model
redis.set('backend_online', 1)
elif opts.mode == 'hf-textgen':
try:
r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
@ -45,6 +64,14 @@ class MainBackgroundThread(Thread):
redis.set('backend_online', 0)
# TODO: handle error
print(e)
elif opts.mode == 'vllm':
model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
opts.running_model = model
redis.set('backend_online', 1)
else:
raise Exception
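
A minimal sketch of the full_model_path round trip between this startup loop and the vLLM generate() path above. The Redis key and calls match the diff; the model path value is made up.

```python
from llm_server.routes.cache import redis

# At startup (MainBackgroundThread.__init__): store the backend's full model path.
# The real value comes from the /v1/models response; this one is hypothetical.
redis.set('full_model_path', '/models/some-org/some-model')

# Later, in llm_server/llm/vllm/generate.py, generate() reads it back and
# stamps it onto the outgoing chat-completions request.
json_data = {'messages': [{'role': 'user', 'content': 'Hi'}]}
json_data['model'] = redis.get('full_model_path').decode()  # '/models/some-org/some-model'
```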

View File

@ -9,3 +9,4 @@ redis
gevent
async-timeout
flask-sock
vllm

View File

@ -41,7 +41,7 @@ if config['database_path'].startswith('./'):
opts.database_path = resolve_path(config['database_path'])
init_db()
if config['mode'] not in ['oobabooga', 'hf-textgen']:
if config['mode'] not in ['oobabooga', 'hf-textgen', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
opts.mode = config['mode']
@ -55,7 +55,7 @@ opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
@ -135,8 +135,8 @@ def home():
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=running_model,
client_api=f'https://{opts.base_client_api}',
ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
client_api=stats['endpoints']['blocking'],
ws_client_api=stats['endpoints']['streaming'],
estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],

View File

@ -61,6 +61,10 @@
font-size: 1.5em;
}
}
.hidden {
display: none;
}
</style>
</head>
@ -83,11 +87,17 @@
<strong>Instructions:</strong>
<ol>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
{% if not ws_client_api %}
<li>Set <kbd>Chat Completion Source</kbd> to <kbd>OpenAI</kbd>.</li>
{% endif %}
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
{% if ws_client_api %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
</li>
<li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
</li>
{% endif %}
<li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>