set up queue to work with gunicorn processes, other improvements

This commit is contained in:
Cyberes 2023-09-14 17:38:20 -06:00
parent 5d03f875cb
commit 3100b0a924
9 changed files with 105 additions and 83 deletions

View File

@ -4,7 +4,7 @@ import flask
class LLMBackend:
default_params: dict
_default_params: dict
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError

View File

@ -10,7 +10,7 @@ from llm_server.routes.helpers.http import validate_json
class VLLMBackend(LLMBackend):
default_params = vars(SamplingParams())
_default_params = vars(SamplingParams())
def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
if len(response_json_body.get('text', [])):
@ -25,14 +25,18 @@ class VLLMBackend(LLMBackend):
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
try:
# top_k == -1 means disabled
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self.default_params['temperature']),
top_p=parameters.get('top_p', self.default_params['top_p']),
top_k=parameters.get('top_k', self.default_params['top_k']),
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self.default_params['stop']),
stop=parameters.get('stopping_strings', self._default_params['stop']),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
)
except ValueError as e:
return None, str(e).strip('.')

View File

@ -30,4 +30,4 @@ expose_openai_system_prompt = True
enable_streaming = True
openai_api_key = None
backend_request_timeout = 30
backend_generate_request_timeout = 120
backend_generate_request_timeout = 95

View File

@ -1,12 +1,11 @@
import time
from typing import Tuple
import flask
from flask import jsonify
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler
@ -35,3 +34,8 @@ class OobaRequestHandler(RequestHandler):
return jsonify({
'results': [{'text': backend_response}]
}), 200
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return jsonify({
'results': [{'text': msg}]
}), 200

View File

@ -10,7 +10,6 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
# TODO: make this work with oobabooga
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

View File

@ -48,9 +48,11 @@ class OpenAIRequestHandler(RequestHandler):
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
llm_request = {**self.parameters, 'prompt': self.prompt}
_, (backend_response, backend_response_status_code) = self.generate_response(llm_request)
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
else:
return backend_response, backend_response_status_code
def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@ -81,13 +83,16 @@ class OpenAIRequestHandler(RequestHandler):
prompt += '\n\n### RESPONSE: '
return prompt
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return build_openai_response('', msg), 200
def check_moderation_endpoint(prompt: str):
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {opts.openai_api_key}",
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}).json()
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
offending_categories = []
for k, v in response['results'][0]['categories'].items():
if v:

View File

@ -1,6 +1,10 @@
import heapq
import json
import pickle
import threading
import time
from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.llm.generator import generator
@ -27,58 +31,77 @@ def decrement_ip_count(client_ip: int, redis_key):
return ip_count
class PriorityQueue:
class RedisPriorityQueue:
def __init__(self):
self._queue = []
self._index = 0
self._cv = threading.Condition()
self._lock = threading.Lock()
redis.set_dict('queued_ip_count', {})
self.redis = Redis(host='localhost', port=6379, db=15)
# Clear the DB
for key in self.redis.scan_iter('*'):
self.redis.delete(key)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')
def put(self, item, priority):
event = DataEvent()
with self._cv:
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = redis.get_dict('queued_ip_count')
if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
return None # reject the request
heapq.heappush(self._queue, (-priority, self._index, item, event))
self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
self._index += 1
# Increment the count for this IP
with self._lock:
increment_ip_count(item[1], 'queued_ip_count')
self._cv.notify()
self.increment_ip_count(item[1], 'queued_ip_count')
return event
def get(self):
with self._cv:
while len(self._queue) == 0:
self._cv.wait()
_, _, item, event = heapq.heappop(self._queue)
while True:
data = self.redis.zpopmin('queue')
if data:
item = json.loads(data[0][0])
client_ip = item[1][1]
# Decrement the count for this IP
with self._lock:
decrement_ip_count(item[1], 'queued_ip_count')
return item, event
self.decrement_ip_count(client_ip, 'queued_ip_count')
return item
time.sleep(1) # wait for an item to be added to the queue
def increment_ip_count(self, ip, key):
self.redis.hincrby(key, ip, 1)
def decrement_ip_count(self, ip, key):
self.redis.hincrby(key, ip, -1)
def __len__(self):
return len(self._queue)
return self.redis.zcard('queue')
priority_queue = PriorityQueue()
class DataEvent:
def __init__(self, event_id=None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe(self.event_id)
def set(self, data):
self.redis.publish(self.event_id, pickle.dumps(data))
def wait(self):
for item in self.pubsub.listen():
if item['type'] == 'message':
return pickle.loads(item['data'])
class DataEvent(threading.Event):
def __init__(self):
super().__init__()
self.data = None
priority_queue = RedisPriorityQueue()
def worker():
global processing_ips_lock
while True:
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
# redis.sadd('processing_ips', client_ip)
increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers')
@ -91,10 +114,9 @@ def worker():
with generation_elapsed_lock:
generation_elapsed.append((end_time, elapsed_time))
event.data = (success, response, error_msg)
event.set()
event = DataEvent(event_id)
event.set((success, response, error_msg))
# redis.srem('processing_ips', client_ip)
decrement_ip_count(client_ip, 'processing_ips')
redis.decr('active_gen_workers')

View File

@ -3,7 +3,7 @@ import time
from typing import Tuple, Union
import flask
from flask import Response, jsonify
from flask import Response
from llm_server import opts
from llm_server.database import log_prompt
@ -27,7 +27,7 @@ class RequestHandler:
self.token = self.request.headers.get('X-Api-Key')
self.priority = self.get_priority()
self.backend = get_backend()
self.parameters = self.parameters_invalid_msg = None
self.parameters = None
self.used = False
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
@ -50,31 +50,26 @@ class RequestHandler:
return result[0]
return DEFAULT_PRIORITY
def load_parameters(self):
# Handle OpenAI
def get_parameters(self):
if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
self.load_parameters()
params_valid = False
self.parameters, parameters_invalid_msg = self.get_parameters()
request_valid = False
invalid_request_err_msg = None
if self.parameters:
params_valid = True
request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
if not request_valid or not params_valid:
error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
if not request_valid:
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
return False, (jsonify({
'code': 400,
'msg': 'parameter validation error',
'results': [{'text': err}]
}), 200)
return False, self.handle_error(backend_response)
return True, (None, 0)
def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
@ -88,9 +83,7 @@ class RequestHandler:
prompt = llm_request['prompt']
event.wait()
success, response, error_msg = event.data
success, response, error_msg = event.wait()
end_time = time.time()
elapsed_time = end_time - self.start_time
@ -113,11 +106,7 @@ class RequestHandler:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
return (False, None, None, 0), (jsonify({
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 200)
return (False, None, None, 0), self.handle_error(backend_response)
# ===============================================
@ -137,11 +126,7 @@ class RequestHandler:
error_msg = 'The backend did not return valid JSON.'
backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
return (False, None, None, 0), (jsonify({
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 200)
return (False, None, None, 0), self.handle_error(backend_response)
# ===============================================
@ -164,6 +149,9 @@ class RequestHandler:
def handle_ratelimited(self) -> Tuple[flask.Response, int]:
raise NotImplementedError
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
raise NotImplementedError
def get_backend():
if opts.mode == 'oobabooga':

View File

@ -186,4 +186,4 @@ def server_error(e):
if __name__ == "__main__":
app.run(host='0.0.0.0')
app.run(host='0.0.0.0', threaded=False, processes=15)