set up queue to work with gunicorn processes, other improvements

parent 5d03f875cb
commit 3100b0a924

@@ -4,7 +4,7 @@ import flask


 class LLMBackend:
-    default_params: dict
+    _default_params: dict

     def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError

@@ -10,7 +10,7 @@ from llm_server.routes.helpers.http import validate_json


 class VLLMBackend(LLMBackend):
-    default_params = vars(SamplingParams())
+    _default_params = vars(SamplingParams())

     def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
         if len(response_json_body.get('text', [])):

@@ -25,14 +25,18 @@ class VLLMBackend(LLMBackend):

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         try:
+            # top_k == -1 means disabled
+            top_k = parameters.get('top_k', self._default_params['top_k'])
+            if top_k <= 0:
+                top_k = -1
             sampling_params = SamplingParams(
-                temperature=parameters.get('temperature', self.default_params['temperature']),
-                top_p=parameters.get('top_p', self.default_params['top_p']),
-                top_k=parameters.get('top_k', self.default_params['top_k']),
+                temperature=parameters.get('temperature', self._default_params['temperature']),
+                top_p=parameters.get('top_p', self._default_params['top_p']),
+                top_k=top_k,
                 use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-                stop=parameters.get('stopping_strings', self.default_params['stop']),
+                stop=parameters.get('stopping_strings', self._default_params['stop']),
                 ignore_eos=parameters.get('ban_eos_token', False),
-                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
+                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
             )
         except ValueError as e:
             return None, str(e).strip('.')

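The renamed `_default_params` is still populated from `vars(SamplingParams())`, so any sampling key the client omits falls back to vLLM's own defaults, and the new guard folds `top_k <= 0` into `-1`, which vLLM treats as "top-k disabled". A minimal sketch of that behaviour in isolation, assuming vLLM is installed (the helper name below is illustrative, not part of the commit):

from vllm import SamplingParams

_default_params = vars(SamplingParams())  # vLLM's own defaults, e.g. top_k == -1

def normalize_top_k(parameters: dict) -> int:
    # Fall back to the vLLM default, then fold 0/negative values into -1 ("disabled").
    top_k = parameters.get('top_k', _default_params['top_k'])
    return -1 if top_k <= 0 else top_k

print(normalize_top_k({}))             # -1 (vLLM default)
print(normalize_top_k({'top_k': 0}))   # -1 (explicitly disabled)
print(normalize_top_k({'top_k': 40}))  # 40
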
@@ -30,4 +30,4 @@ expose_openai_system_prompt = True
 enable_streaming = True
 openai_api_key = None
 backend_request_timeout = 30
-backend_generate_request_timeout = 120
+backend_generate_request_timeout = 95

@@ -1,12 +1,11 @@
-import time
+from typing import Tuple

+import flask
 from flask import jsonify

 from llm_server import opts
 from llm_server.database import log_prompt
 from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-from llm_server.routes.queue import priority_queue
 from llm_server.routes.request_handler import RequestHandler

@@ -35,3 +34,8 @@ class OobaRequestHandler(RequestHandler):
         return jsonify({
             'results': [{'text': backend_response}]
         }), 200
+
+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        return jsonify({
+            'results': [{'text': msg}]
+        }), 200

@@ -10,7 +10,6 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response

 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
-    # TODO: make this work with oobabooga
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

@@ -48,9 +48,11 @@ class OpenAIRequestHandler(RequestHandler):
         # Reconstruct the request JSON with the validated parameters and prompt.
         self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
         llm_request = {**self.parameters, 'prompt': self.prompt}
-        _, (backend_response, backend_response_status_code) = self.generate_response(llm_request)
-        return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
+        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
+        if success:
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
+        else:
+            return backend_response, backend_response_status_code

     def handle_ratelimited(self):
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')

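The OpenAI handler now checks the `success` flag instead of assuming generation worked. Per the updated `generate_response` signature in the request-handler hunks below, the call returns two tuples: `(success, response, error_msg, elapsed_time)` and `(flask_response, status_code)`, where the second part is already a fully formed Flask response. A rough sketch of the calling convention (`handler` and `llm_request` are illustrative stand-ins, not names from the diff):

(success, _, error_msg, _), (backend_response, status_code) = handler.generate_response(llm_request)
if success:
    # backend_response carries the Ooba-style payload; pull the generated text out of it
    text = backend_response.json['results'][0]['text']
else:
    # on failure, backend_response is already an error response built by handle_error()
    return backend_response, status_code
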
@@ -81,13 +83,16 @@ class OpenAIRequestHandler(RequestHandler):
         prompt += '\n\n### RESPONSE: '
         return prompt

+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        return build_openai_response('', msg), 200
+

 def check_moderation_endpoint(prompt: str):
     headers = {
         'Content-Type': 'application/json',
         'Authorization': f"Bearer {opts.openai_api_key}",
     }
-    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}).json()
+    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
     offending_categories = []
     for k, v in response['results'][0]['categories'].items():
         if v:

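`requests` applies no timeout by default, so before this change a hung moderation call could tie up a worker indefinitely; `timeout=10` bounds it, but it also means the call can now raise `requests.exceptions.Timeout`. The commit itself does not catch that, so a caller might guard it along these lines (a sketch, not part of the diff; the exact return shape of `check_moderation_endpoint` is not shown in this hunk):

import requests

try:
    result = check_moderation_endpoint(prompt)
except requests.exceptions.Timeout:
    result = None  # fail open instead of stalling the request
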
@@ -1,6 +1,10 @@
-import heapq
+import json
+import pickle
 import threading
 import time
+from uuid import uuid4

+from redis import Redis
+
 from llm_server import opts
 from llm_server.llm.generator import generator

@@ -27,58 +31,77 @@ def decrement_ip_count(client_ip: int, redis_key):
     return ip_count


-class PriorityQueue:
+class RedisPriorityQueue:
     def __init__(self):
-        self._queue = []
         self._index = 0
-        self._cv = threading.Condition()
         self._lock = threading.Lock()
-        redis.set_dict('queued_ip_count', {})
+        self.redis = Redis(host='localhost', port=6379, db=15)
+
+        # Clear the DB
+        for key in self.redis.scan_iter('*'):
+            self.redis.delete(key)
+
+        self.pubsub = self.redis.pubsub()
+        self.pubsub.subscribe('events')

     def put(self, item, priority):
         event = DataEvent()
-        with self._cv:
-            # Check if the IP is already in the dictionary and if it has reached the limit
-            ip_count = redis.get_dict('queued_ip_count')
-            if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
-                return None  # reject the request
-            heapq.heappush(self._queue, (-priority, self._index, item, event))
-            self._index += 1
-            # Increment the count for this IP
-            with self._lock:
-                increment_ip_count(item[1], 'queued_ip_count')
-            self._cv.notify()
+        # Check if the IP is already in the dictionary and if it has reached the limit
+        ip_count = self.redis.hget('queued_ip_count', item[1])
+        if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
+            return None  # reject the request
+        self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
+        self._index += 1
+        # Increment the count for this IP
+        with self._lock:
+            self.increment_ip_count(item[1], 'queued_ip_count')
         return event

     def get(self):
-        with self._cv:
-            while len(self._queue) == 0:
-                self._cv.wait()
-            _, _, item, event = heapq.heappop(self._queue)
-            # Decrement the count for this IP
-            with self._lock:
-                decrement_ip_count(item[1], 'queued_ip_count')
-            return item, event
+        while True:
+            data = self.redis.zpopmin('queue')
+            if data:
+                item = json.loads(data[0][0])
+                client_ip = item[1][1]
+                # Decrement the count for this IP
+                with self._lock:
+                    self.decrement_ip_count(client_ip, 'queued_ip_count')
+                return item
+            time.sleep(1)  # wait for an item to be added to the queue
+
+    def increment_ip_count(self, ip, key):
+        self.redis.hincrby(key, ip, 1)
+
+    def decrement_ip_count(self, ip, key):
+        self.redis.hincrby(key, ip, -1)

     def __len__(self):
-        return len(self._queue)
+        return self.redis.zcard('queue')


-priority_queue = PriorityQueue()
+class DataEvent:
+    def __init__(self, event_id=None):
+        self.event_id = event_id if event_id else str(uuid4())
+        self.redis = Redis(host='localhost', port=6379, db=14)
+        self.pubsub = self.redis.pubsub()
+        self.pubsub.subscribe(self.event_id)
+
+    def set(self, data):
+        self.redis.publish(self.event_id, pickle.dumps(data))
+
+    def wait(self):
+        for item in self.pubsub.listen():
+            if item['type'] == 'message':
+                return pickle.loads(item['data'])


-class DataEvent(threading.Event):
-    def __init__(self):
-        super().__init__()
-        self.data = None
+priority_queue = RedisPriorityQueue()


 def worker():
-    global processing_ips_lock
     while True:
-        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

-        # redis.sadd('processing_ips', client_ip)
         increment_ip_count(client_ip, 'processing_ips')

         redis.incr('active_gen_workers')

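This hunk is the heart of the commit. The old `PriorityQueue` was an in-process `heapq` guarded by a `threading.Condition`, and the old `DataEvent` was a `threading.Event`; both only work while every request handler and the generation worker share one process, which stops being true once gunicorn forks multiple workers. Storing the queue in a Redis sorted set (scored by negated priority) and signalling completion over a Redis pub/sub channel keyed by a per-request UUID lets any process enqueue and any process consume. A minimal sketch of the cross-process round trip, assuming a local Redis and the names defined in this file (the result tuple is a stand-in):

# In a web worker (any gunicorn process): enqueue and block on the event channel.
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
if event is None:
    ...  # per-IP queue limit reached; the request is rejected
else:
    success, response, error_msg = event.wait()  # blocks on the pub/sub channel

# In the generation worker (a separate process): pop the item and publish the result.
index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
DataEvent(event_id).set((True, None, None))  # stand-in; the real worker sends (success, response, error_msg)

Because `DataEvent` subscribes to its channel in the constructor, the waiting side is subscribed before the worker ever publishes, so the result is not lost between processes.
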
@@ -91,10 +114,9 @@ def worker():
         with generation_elapsed_lock:
             generation_elapsed.append((end_time, elapsed_time))

-        event.data = (success, response, error_msg)
-        event.set()
+        event = DataEvent(event_id)
+        event.set((success, response, error_msg))

-        # redis.srem('processing_ips', client_ip)
         decrement_ip_count(client_ip, 'processing_ips')
         redis.decr('active_gen_workers')

@@ -3,7 +3,7 @@ import time
 from typing import Tuple, Union

 import flask
-from flask import Response, jsonify
+from flask import Response

 from llm_server import opts
 from llm_server.database import log_prompt

@@ -27,7 +27,7 @@ class RequestHandler:
         self.token = self.request.headers.get('X-Api-Key')
         self.priority = self.get_priority()
         self.backend = get_backend()
-        self.parameters = self.parameters_invalid_msg = None
+        self.parameters = None
         self.used = False
         SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

@@ -50,31 +50,26 @@ class RequestHandler:
             return result[0]
         return DEFAULT_PRIORITY

-    def load_parameters(self):
-        # Handle OpenAI
+    def get_parameters(self):
         if self.request_json_body.get('max_tokens'):
             self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
-        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+        return parameters, parameters_invalid_msg

     def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
-        self.load_parameters()
-        params_valid = False
+        self.parameters, parameters_invalid_msg = self.get_parameters()
         request_valid = False
+        invalid_request_err_msg = None
         if self.parameters:
-            params_valid = True
             request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)

-        if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
+        if not request_valid:
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
-            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
             # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
-            return False, (jsonify({
-                'code': 400,
-                'msg': 'parameter validation error',
-                'results': [{'text': err}]
-            }), 200)
+            return False, self.handle_error(backend_response)
         return True, (None, 0)

     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:

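`get_parameters` now returns its results instead of stashing them on the instance, and `validate_request` merges the two possible failure sources (an invalid request reported by the backend and an invalid-parameters message from `get_parameters`) into one SillyTavern-formatted error before delegating the HTTP shape to the subclass's `handle_error`. An illustrative walk-through of the merge, with made-up messages:

# Illustrative values only; shows how the two validation results combine.
request_valid, invalid_request_err_msg = False, 'prompt is too long'
parameters_invalid_msg = 'temperature must be positive'

error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg),
                                         (not bool(parameters_invalid_msg), parameters_invalid_msg)]
                  if not valid and msg]
combined_error_message = ', '.join(error_messages)
# -> 'prompt is too long, temperature must be positive'
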
@@ -88,9 +83,7 @@ class RequestHandler:

         prompt = llm_request['prompt']

-        event.wait()
-        success, response, error_msg = event.data
+        success, response, error_msg = event.wait()

         end_time = time.time()
         elapsed_time = end_time - self.start_time

@@ -113,11 +106,7 @@ class RequestHandler:
             error_msg = error_msg.strip('.') + '.'
             backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), (jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 200)
+            return (False, None, None, 0), self.handle_error(backend_response)

         # ===============================================

@@ -137,11 +126,7 @@ class RequestHandler:
             error_msg = 'The backend did not return valid JSON.'
             backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), (jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 200)
+            return (False, None, None, 0), self.handle_error(backend_response)

         # ===============================================

@@ -164,6 +149,9 @@ class RequestHandler:
     def handle_ratelimited(self) -> Tuple[flask.Response, int]:
         raise NotImplementedError

+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        raise NotImplementedError
+

 def get_backend():
     if opts.mode == 'oobabooga':

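`handle_error` joins `handle_ratelimited` as an abstract hook, which is what lets the shared error paths above call `self.handle_error(backend_response)` without caring which API style is in use. Condensed from the earlier hunks, the two concrete overrides look like this (bodies abbreviated to the essentials):

class OobaRequestHandler(RequestHandler):
    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        return jsonify({'results': [{'text': msg}]}), 200

class OpenAIRequestHandler(RequestHandler):
    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        return build_openai_response('', msg), 200
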