set up queue to work with gunicorn processes, other improvements

This commit is contained in:
Cyberes 2023-09-14 17:38:20 -06:00
parent 5d03f875cb
commit 3100b0a924
9 changed files with 105 additions and 83 deletions

View File

@ -4,7 +4,7 @@ import flask
class LLMBackend: class LLMBackend:
default_params: dict _default_params: dict
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers): def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError raise NotImplementedError

View File

@ -10,7 +10,7 @@ from llm_server.routes.helpers.http import validate_json
class VLLMBackend(LLMBackend): class VLLMBackend(LLMBackend):
default_params = vars(SamplingParams()) _default_params = vars(SamplingParams())
def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers): def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
if len(response_json_body.get('text', [])): if len(response_json_body.get('text', [])):
@ -25,14 +25,18 @@ class VLLMBackend(LLMBackend):
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]: def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
try: try:
# top_k == -1 means disabled
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=parameters.get('temperature', self.default_params['temperature']), temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self.default_params['top_p']), top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=parameters.get('top_k', self.default_params['top_k']), top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False, use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self.default_params['stop']), stop=parameters.get('stopping_strings', self._default_params['stop']),
ignore_eos=parameters.get('ban_eos_token', False), ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens']) max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
) )
except ValueError as e: except ValueError as e:
return None, str(e).strip('.') return None, str(e).strip('.')

View File

@ -30,4 +30,4 @@ expose_openai_system_prompt = True
enable_streaming = True enable_streaming = True
openai_api_key = None openai_api_key = None
backend_request_timeout = 30 backend_request_timeout = 30
backend_generate_request_timeout = 120 backend_generate_request_timeout = 95

View File

@ -1,12 +1,11 @@
import time from typing import Tuple
import flask
from flask import jsonify from flask import jsonify
from llm_server import opts from llm_server import opts
from llm_server.database import log_prompt from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
@ -35,3 +34,8 @@ class OobaRequestHandler(RequestHandler):
return jsonify({ return jsonify({
'results': [{'text': backend_response}] 'results': [{'text': backend_response}]
}), 200 }), 200
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return jsonify({
'results': [{'text': msg}]
}), 200

View File

@ -10,7 +10,6 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
@openai_bp.route('/chat/completions', methods=['POST']) @openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions(): def openai_chat_completions():
# TODO: make this work with oobabooga
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages'): if not request_valid_json or not request_json_body.get('messages'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

View File

@ -48,9 +48,11 @@ class OpenAIRequestHandler(RequestHandler):
# Reconstruct the request JSON with the validated parameters and prompt. # Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']) self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
llm_request = {**self.parameters, 'prompt': self.prompt} llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
_, (backend_response, backend_response_status_code) = self.generate_response(llm_request) if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
else:
return backend_response, backend_response_status_code
def handle_ratelimited(self): def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@ -81,13 +83,16 @@ class OpenAIRequestHandler(RequestHandler):
prompt += '\n\n### RESPONSE: ' prompt += '\n\n### RESPONSE: '
return prompt return prompt
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return build_openai_response('', msg), 200
def check_moderation_endpoint(prompt: str): def check_moderation_endpoint(prompt: str):
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': f"Bearer {opts.openai_api_key}", 'Authorization': f"Bearer {opts.openai_api_key}",
} }
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}).json() response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
offending_categories = [] offending_categories = []
for k, v in response['results'][0]['categories'].items(): for k, v in response['results'][0]['categories'].items():
if v: if v:

View File

@ -1,6 +1,10 @@
import heapq import json
import pickle
import threading import threading
import time import time
from uuid import uuid4
from redis import Redis
from llm_server import opts from llm_server import opts
from llm_server.llm.generator import generator from llm_server.llm.generator import generator
@ -27,58 +31,77 @@ def decrement_ip_count(client_ip: int, redis_key):
return ip_count return ip_count
class PriorityQueue: class RedisPriorityQueue:
def __init__(self): def __init__(self):
self._queue = []
self._index = 0 self._index = 0
self._cv = threading.Condition()
self._lock = threading.Lock() self._lock = threading.Lock()
redis.set_dict('queued_ip_count', {}) self.redis = Redis(host='localhost', port=6379, db=15)
# Clear the DB
for key in self.redis.scan_iter('*'):
self.redis.delete(key)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')
def put(self, item, priority): def put(self, item, priority):
event = DataEvent() event = DataEvent()
with self._cv: # Check if the IP is already in the dictionary and if it has reached the limit
# Check if the IP is already in the dictionary and if it has reached the limit ip_count = self.redis.hget('queued_ip_count', item[1])
ip_count = redis.get_dict('queued_ip_count') if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0: return None # reject the request
return None # reject the request self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
heapq.heappush(self._queue, (-priority, self._index, item, event)) self._index += 1
self._index += 1 # Increment the count for this IP
# Increment the count for this IP with self._lock:
with self._lock: self.increment_ip_count(item[1], 'queued_ip_count')
increment_ip_count(item[1], 'queued_ip_count')
self._cv.notify()
return event return event
def get(self): def get(self):
with self._cv: while True:
while len(self._queue) == 0: data = self.redis.zpopmin('queue')
self._cv.wait() if data:
_, _, item, event = heapq.heappop(self._queue) item = json.loads(data[0][0])
# Decrement the count for this IP client_ip = item[1][1]
with self._lock: # Decrement the count for this IP
decrement_ip_count(item[1], 'queued_ip_count') with self._lock:
return item, event self.decrement_ip_count(client_ip, 'queued_ip_count')
return item
time.sleep(1) # wait for an item to be added to the queue
def increment_ip_count(self, ip, key):
self.redis.hincrby(key, ip, 1)
def decrement_ip_count(self, ip, key):
self.redis.hincrby(key, ip, -1)
def __len__(self): def __len__(self):
return len(self._queue) return self.redis.zcard('queue')
priority_queue = PriorityQueue() class DataEvent:
def __init__(self, event_id=None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe(self.event_id)
def set(self, data):
self.redis.publish(self.event_id, pickle.dumps(data))
def wait(self):
for item in self.pubsub.listen():
if item['type'] == 'message':
return pickle.loads(item['data'])
class DataEvent(threading.Event): priority_queue = RedisPriorityQueue()
def __init__(self):
super().__init__()
self.data = None
def worker(): def worker():
global processing_ips_lock
while True: while True:
(request_json_body, client_ip, token, parameters), event = priority_queue.get() index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
# redis.sadd('processing_ips', client_ip)
increment_ip_count(client_ip, 'processing_ips') increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers') redis.incr('active_gen_workers')
@ -91,10 +114,9 @@ def worker():
with generation_elapsed_lock: with generation_elapsed_lock:
generation_elapsed.append((end_time, elapsed_time)) generation_elapsed.append((end_time, elapsed_time))
event.data = (success, response, error_msg) event = DataEvent(event_id)
event.set() event.set((success, response, error_msg))
# redis.srem('processing_ips', client_ip)
decrement_ip_count(client_ip, 'processing_ips') decrement_ip_count(client_ip, 'processing_ips')
redis.decr('active_gen_workers') redis.decr('active_gen_workers')

View File

@ -3,7 +3,7 @@ import time
from typing import Tuple, Union from typing import Tuple, Union
import flask import flask
from flask import Response, jsonify from flask import Response
from llm_server import opts from llm_server import opts
from llm_server.database import log_prompt from llm_server.database import log_prompt
@ -27,7 +27,7 @@ class RequestHandler:
self.token = self.request.headers.get('X-Api-Key') self.token = self.request.headers.get('X-Api-Key')
self.priority = self.get_priority() self.priority = self.get_priority()
self.backend = get_backend() self.backend = get_backend()
self.parameters = self.parameters_invalid_msg = None self.parameters = None
self.used = False self.used = False
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time() SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
@ -50,31 +50,26 @@ class RequestHandler:
return result[0] return result[0]
return DEFAULT_PRIORITY return DEFAULT_PRIORITY
def load_parameters(self): def get_parameters(self):
# Handle OpenAI
if self.request_json_body.get('max_tokens'): if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens') self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body) parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]: def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
self.load_parameters() self.parameters, parameters_invalid_msg = self.get_parameters()
params_valid = False
request_valid = False request_valid = False
invalid_request_err_msg = None
if self.parameters: if self.parameters:
params_valid = True
request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters) request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
if not request_valid or not params_valid: if not request_valid:
error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg] error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
combined_error_message = ', '.join(error_messages) combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
return False, (jsonify({ return False, self.handle_error(backend_response)
'code': 400,
'msg': 'parameter validation error',
'results': [{'text': err}]
}), 200)
return True, (None, 0) return True, (None, 0)
def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]: def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
@ -88,9 +83,7 @@ class RequestHandler:
prompt = llm_request['prompt'] prompt = llm_request['prompt']
event.wait() success, response, error_msg = event.wait()
success, response, error_msg = event.data
end_time = time.time() end_time = time.time()
elapsed_time = end_time - self.start_time elapsed_time = end_time - self.start_time
@ -113,11 +106,7 @@ class RequestHandler:
error_msg = error_msg.strip('.') + '.' error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, 'error') backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
return (False, None, None, 0), (jsonify({ return (False, None, None, 0), self.handle_error(backend_response)
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 200)
# =============================================== # ===============================================
@ -137,11 +126,7 @@ class RequestHandler:
error_msg = 'The backend did not return valid JSON.' error_msg = 'The backend did not return valid JSON.'
backend_response = format_sillytavern_err(error_msg, 'error') backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
return (False, None, None, 0), (jsonify({ return (False, None, None, 0), self.handle_error(backend_response)
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 200)
# =============================================== # ===============================================
@ -164,6 +149,9 @@ class RequestHandler:
def handle_ratelimited(self) -> Tuple[flask.Response, int]: def handle_ratelimited(self) -> Tuple[flask.Response, int]:
raise NotImplementedError raise NotImplementedError
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
raise NotImplementedError
def get_backend(): def get_backend():
if opts.mode == 'oobabooga': if opts.mode == 'oobabooga':

View File

@ -186,4 +186,4 @@ def server_error(e):
if __name__ == "__main__": if __name__ == "__main__":
app.run(host='0.0.0.0') app.run(host='0.0.0.0', threaded=False, processes=15)