set up queue to work with gunicorn processes, other improvements
This commit is contained in:
parent
5d03f875cb
commit
3100b0a924
|
@ -4,7 +4,7 @@ import flask
|
|||
|
||||
|
||||
class LLMBackend:
|
||||
default_params: dict
|
||||
_default_params: dict
|
||||
|
||||
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -10,7 +10,7 @@ from llm_server.routes.helpers.http import validate_json
|
|||
|
||||
|
||||
class VLLMBackend(LLMBackend):
|
||||
default_params = vars(SamplingParams())
|
||||
_default_params = vars(SamplingParams())
|
||||
|
||||
def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
|
||||
if len(response_json_body.get('text', [])):
|
||||
|
@ -25,14 +25,18 @@ class VLLMBackend(LLMBackend):
|
|||
|
||||
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
|
||||
try:
|
||||
# top_k == -1 means disabled
|
||||
top_k = parameters.get('top_k', self._default_params['top_k'])
|
||||
if top_k <= 0:
|
||||
top_k = -1
|
||||
sampling_params = SamplingParams(
|
||||
temperature=parameters.get('temperature', self.default_params['temperature']),
|
||||
top_p=parameters.get('top_p', self.default_params['top_p']),
|
||||
top_k=parameters.get('top_k', self.default_params['top_k']),
|
||||
temperature=parameters.get('temperature', self._default_params['temperature']),
|
||||
top_p=parameters.get('top_p', self._default_params['top_p']),
|
||||
top_k=top_k,
|
||||
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
|
||||
stop=parameters.get('stopping_strings', self.default_params['stop']),
|
||||
stop=parameters.get('stopping_strings', self._default_params['stop']),
|
||||
ignore_eos=parameters.get('ban_eos_token', False),
|
||||
max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
|
||||
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
|
||||
)
|
||||
except ValueError as e:
|
||||
return None, str(e).strip('.')
|
||||
|
|
|
@ -30,4 +30,4 @@ expose_openai_system_prompt = True
|
|||
enable_streaming = True
|
||||
openai_api_key = None
|
||||
backend_request_timeout = 30
|
||||
backend_generate_request_timeout = 120
|
||||
backend_generate_request_timeout = 95
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import flask
|
||||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.helpers.http import validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
|
||||
|
@ -35,3 +34,8 @@ class OobaRequestHandler(RequestHandler):
|
|||
return jsonify({
|
||||
'results': [{'text': backend_response}]
|
||||
}), 200
|
||||
|
||||
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
|
||||
return jsonify({
|
||||
'results': [{'text': msg}]
|
||||
}), 200
|
||||
|
|
|
@ -10,7 +10,6 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
|
|||
|
||||
@openai_bp.route('/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions():
|
||||
# TODO: make this work with oobabooga
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('messages'):
|
||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
|
|
|
@ -48,9 +48,11 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
|
||||
_, (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
|
||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
if success:
|
||||
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
|
||||
else:
|
||||
return backend_response, backend_response_status_code
|
||||
|
||||
def handle_ratelimited(self):
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
|
@ -81,13 +83,16 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
prompt += '\n\n### RESPONSE: '
|
||||
return prompt
|
||||
|
||||
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
|
||||
return build_openai_response('', msg), 200
|
||||
|
||||
|
||||
def check_moderation_endpoint(prompt: str):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {opts.openai_api_key}",
|
||||
}
|
||||
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}).json()
|
||||
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
|
||||
offending_categories = []
|
||||
for k, v in response['results'][0]['categories'].items():
|
||||
if v:
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import heapq
|
||||
import json
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm.generator import generator
|
||||
|
@ -27,58 +31,77 @@ def decrement_ip_count(client_ip: int, redis_key):
|
|||
return ip_count
|
||||
|
||||
|
||||
class PriorityQueue:
|
||||
class RedisPriorityQueue:
|
||||
def __init__(self):
|
||||
self._queue = []
|
||||
self._index = 0
|
||||
self._cv = threading.Condition()
|
||||
self._lock = threading.Lock()
|
||||
redis.set_dict('queued_ip_count', {})
|
||||
self.redis = Redis(host='localhost', port=6379, db=15)
|
||||
|
||||
# Clear the DB
|
||||
for key in self.redis.scan_iter('*'):
|
||||
self.redis.delete(key)
|
||||
|
||||
self.pubsub = self.redis.pubsub()
|
||||
self.pubsub.subscribe('events')
|
||||
|
||||
def put(self, item, priority):
|
||||
event = DataEvent()
|
||||
with self._cv:
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
ip_count = redis.get_dict('queued_ip_count')
|
||||
if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
|
||||
return None # reject the request
|
||||
heapq.heappush(self._queue, (-priority, self._index, item, event))
|
||||
self._index += 1
|
||||
# Increment the count for this IP
|
||||
with self._lock:
|
||||
increment_ip_count(item[1], 'queued_ip_count')
|
||||
self._cv.notify()
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
ip_count = self.redis.hget('queued_ip_count', item[1])
|
||||
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
|
||||
return None # reject the request
|
||||
self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
|
||||
self._index += 1
|
||||
# Increment the count for this IP
|
||||
with self._lock:
|
||||
self.increment_ip_count(item[1], 'queued_ip_count')
|
||||
return event
|
||||
|
||||
def get(self):
|
||||
with self._cv:
|
||||
while len(self._queue) == 0:
|
||||
self._cv.wait()
|
||||
_, _, item, event = heapq.heappop(self._queue)
|
||||
# Decrement the count for this IP
|
||||
with self._lock:
|
||||
decrement_ip_count(item[1], 'queued_ip_count')
|
||||
return item, event
|
||||
while True:
|
||||
data = self.redis.zpopmin('queue')
|
||||
if data:
|
||||
item = json.loads(data[0][0])
|
||||
client_ip = item[1][1]
|
||||
# Decrement the count for this IP
|
||||
with self._lock:
|
||||
self.decrement_ip_count(client_ip, 'queued_ip_count')
|
||||
return item
|
||||
time.sleep(1) # wait for an item to be added to the queue
|
||||
|
||||
def increment_ip_count(self, ip, key):
|
||||
self.redis.hincrby(key, ip, 1)
|
||||
|
||||
def decrement_ip_count(self, ip, key):
|
||||
self.redis.hincrby(key, ip, -1)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._queue)
|
||||
return self.redis.zcard('queue')
|
||||
|
||||
|
||||
priority_queue = PriorityQueue()
|
||||
class DataEvent:
|
||||
def __init__(self, event_id=None):
|
||||
self.event_id = event_id if event_id else str(uuid4())
|
||||
self.redis = Redis(host='localhost', port=6379, db=14)
|
||||
self.pubsub = self.redis.pubsub()
|
||||
self.pubsub.subscribe(self.event_id)
|
||||
|
||||
def set(self, data):
|
||||
self.redis.publish(self.event_id, pickle.dumps(data))
|
||||
|
||||
def wait(self):
|
||||
for item in self.pubsub.listen():
|
||||
if item['type'] == 'message':
|
||||
return pickle.loads(item['data'])
|
||||
|
||||
|
||||
class DataEvent(threading.Event):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.data = None
|
||||
priority_queue = RedisPriorityQueue()
|
||||
|
||||
|
||||
def worker():
|
||||
global processing_ips_lock
|
||||
while True:
|
||||
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
|
||||
index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
|
||||
|
||||
# redis.sadd('processing_ips', client_ip)
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
|
||||
redis.incr('active_gen_workers')
|
||||
|
@ -91,10 +114,9 @@ def worker():
|
|||
with generation_elapsed_lock:
|
||||
generation_elapsed.append((end_time, elapsed_time))
|
||||
|
||||
event.data = (success, response, error_msg)
|
||||
event.set()
|
||||
event = DataEvent(event_id)
|
||||
event.set((success, response, error_msg))
|
||||
|
||||
# redis.srem('processing_ips', client_ip)
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
redis.decr('active_gen_workers')
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
from typing import Tuple, Union
|
||||
|
||||
import flask
|
||||
from flask import Response, jsonify
|
||||
from flask import Response
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
|
@ -27,7 +27,7 @@ class RequestHandler:
|
|||
self.token = self.request.headers.get('X-Api-Key')
|
||||
self.priority = self.get_priority()
|
||||
self.backend = get_backend()
|
||||
self.parameters = self.parameters_invalid_msg = None
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
|
||||
|
||||
|
@ -50,31 +50,26 @@ class RequestHandler:
|
|||
return result[0]
|
||||
return DEFAULT_PRIORITY
|
||||
|
||||
def load_parameters(self):
|
||||
# Handle OpenAI
|
||||
def get_parameters(self):
|
||||
if self.request_json_body.get('max_tokens'):
|
||||
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
|
||||
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
|
||||
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
|
||||
return parameters, parameters_invalid_msg
|
||||
|
||||
def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
self.load_parameters()
|
||||
params_valid = False
|
||||
self.parameters, parameters_invalid_msg = self.get_parameters()
|
||||
request_valid = False
|
||||
invalid_request_err_msg = None
|
||||
if self.parameters:
|
||||
params_valid = True
|
||||
request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
|
||||
|
||||
if not request_valid or not params_valid:
|
||||
error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
|
||||
if not request_valid:
|
||||
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
|
||||
combined_error_message = ', '.join(error_messages)
|
||||
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
|
||||
backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
|
||||
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
|
||||
return False, (jsonify({
|
||||
'code': 400,
|
||||
'msg': 'parameter validation error',
|
||||
'results': [{'text': err}]
|
||||
}), 200)
|
||||
return False, self.handle_error(backend_response)
|
||||
return True, (None, 0)
|
||||
|
||||
def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
|
||||
|
@ -88,9 +83,7 @@ class RequestHandler:
|
|||
|
||||
prompt = llm_request['prompt']
|
||||
|
||||
event.wait()
|
||||
success, response, error_msg = event.data
|
||||
|
||||
success, response, error_msg = event.wait()
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
|
@ -113,11 +106,7 @@ class RequestHandler:
|
|||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = format_sillytavern_err(error_msg, 'error')
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
return (False, None, None, 0), (jsonify({
|
||||
'code': 500,
|
||||
'msg': error_msg,
|
||||
'results': [{'text': backend_response}]
|
||||
}), 200)
|
||||
return (False, None, None, 0), self.handle_error(backend_response)
|
||||
|
||||
# ===============================================
|
||||
|
||||
|
@ -137,11 +126,7 @@ class RequestHandler:
|
|||
error_msg = 'The backend did not return valid JSON.'
|
||||
backend_response = format_sillytavern_err(error_msg, 'error')
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
return (False, None, None, 0), (jsonify({
|
||||
'code': 500,
|
||||
'msg': error_msg,
|
||||
'results': [{'text': backend_response}]
|
||||
}), 200)
|
||||
return (False, None, None, 0), self.handle_error(backend_response)
|
||||
|
||||
# ===============================================
|
||||
|
||||
|
@ -164,6 +149,9 @@ class RequestHandler:
|
|||
def handle_ratelimited(self) -> Tuple[flask.Response, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_backend():
|
||||
if opts.mode == 'oobabooga':
|
||||
|
|
Reference in New Issue