add ratelimiting to websocket streaming endpoint, fix queue not decrementing IP requests, add console printer

This commit is contained in:
Cyberes 2023-09-27 21:15:54 -06:00
parent 43299b32ad
commit e5fbc9545d
8 changed files with 191 additions and 116 deletions

View File

@ -3,8 +3,9 @@ from llm_server import opts
def generator(request_json_body): def generator(request_json_body):
if opts.mode == 'oobabooga': if opts.mode == 'oobabooga':
from .oobabooga.generate import generate # from .oobabooga.generate import generate
return generate(request_json_body) # return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm': elif opts.mode == 'vllm':
from .vllm.generate import generate from .vllm.generate import generate
r = generate(request_json_body) r = generate(request_json_body)

View File

@ -1,13 +1,11 @@
import json import json
import pickle import pickle
import threading
import time import time
from uuid import uuid4 from uuid import uuid4
from redis import Redis from redis import Redis
from llm_server import opts from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
@ -39,6 +37,8 @@ class RedisPriorityQueue:
# Check if the IP is already in the dictionary and if it has reached the limit # Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1]) ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count:
ip_count = int(ip_count)
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request return None # reject the request
@ -52,11 +52,10 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue') data = self.redis.zpopmin('queue')
if data: if data:
item = json.loads(data[0][0]) item = json.loads(data[0][0])
client_ip = item[1][1] client_ip = item[0][1]
# Decrement the count for this IP
self.decrement_ip_count(client_ip, 'queued_ip_count') self.decrement_ip_count(client_ip, 'queued_ip_count')
return item return item
time.sleep(1) # wait for an item to be added to the queue time.sleep(0.5) # wait for something to be added to the queue
def increment_ip_count(self, ip, key): def increment_ip_count(self, ip, key):
self.redis.hincrby(key, ip, 1) self.redis.hincrby(key, ip, 1)
@ -67,6 +66,13 @@ class RedisPriorityQueue:
def __len__(self): def __len__(self):
return self.redis.zcard('queue') return self.redis.zcard('queue')
def get_ip_count(self, client_ip: str):
x = self.redis.hget('queued_ip_count', client_ip)
if x:
return x.decode('utf-8')
else:
return x
class DataEvent: class DataEvent:
def __init__(self, event_id=None): def __init__(self, event_id=None):
@ -85,32 +91,3 @@ class DataEvent:
priority_queue = RedisPriorityQueue() priority_queue = RedisPriorityQueue()
def worker():
while True:
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers')
try:
start_time = time.time()
success, response, error_msg = generator(request_json_body)
end_time = time.time()
elapsed_time = end_time - start_time
redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
redis.decr('active_gen_workers')
def start_workers(num_workers: int):
for _ in range(num_workers):
t = threading.Thread(target=worker)
t.daemon = True
t.start()

View File

@ -12,7 +12,6 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
@ -134,7 +133,6 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True) request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid: if not request_valid:
return (False, None, None, 0), invalid_response return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
else: else:
event = None event = None
@ -193,14 +191,11 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool: def is_client_ratelimited(self) -> bool:
print('queued_ip_count', redis.get_dict('queued_ip_count'))
print('processing_ips', redis.get_dict('processing_ips'))
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0: if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
return False return False
else: else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count} queued + processing.')
return True return True
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:

View File

@ -2,12 +2,14 @@ import json
import threading import threading
import time import time
import traceback import traceback
from typing import Union
from flask import request from flask import request
from ..helpers.client import format_sillytavern_err from ..cache import redis
from ..helpers.http import require_api_key, validate_json from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import decrement_ip_count, priority_queue
from ... import opts from ... import opts
from ...database.database import log_prompt from ...database.database import log_prompt
from ...llm.generator import generator from ...llm.generator import generator
@ -20,8 +22,31 @@ from ...stream import sock
@sock.route('/api/v1/stream') @sock.route('/api/v1/stream')
def stream(ws): def stream(ws):
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_in_bg(quitting_err_msg, is_error=True)
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
def background_task_exception():
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
if not opts.enable_streaming: if not opts.enable_streaming:
return 'disabled', 401 return 'Streaming is disabled', 401
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
@ -30,12 +55,7 @@ def stream(ws):
message = ws.receive() message = ws.receive()
request_valid_json, request_json_body = validate_json(message) request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'): if not request_valid_json or not request_json_body.get('prompt'):
ws.send(json.dumps({ return 'Invalid JSON', 400
'event': 'text_stream',
'message_num': message_num,
'text': 'Invalid JSON'
}))
message_num += 1
else: else:
if opts.mode != 'vllm': if opts.mode != 'vllm':
# TODO: implement other backends # TODO: implement other backends
@ -50,36 +70,44 @@ def stream(ws):
input_prompt = request_json_body['prompt'] input_prompt = request_json_body['prompt']
response_status_code = 0 response_status_code = 0
start_time = time.time() start_time = time.time()
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close() # this is important if we encountered and error and exited early.
def background_task(): err_msg = None
log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True) if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited()
# TODO: use async/await instead of threads err_msg = r.json['results'][0]['text']
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
else: else:
msg_to_backend = { request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
**handler.parameters, if not request_valid:
'prompt': input_prompt, err_msg = invalid_response[0].json['results'][0]['text']
'stream': True, if err_msg:
} send_err_and_quit(err_msg)
try: return
response = generator(msg_to_backend)
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': msg
}))
else:
# Be extra careful when getting attributes from the response object # Be extra careful when getting attributes from the response object
try: try:
response_status_code = response.status_code response_status_code = response.status_code
@ -88,14 +116,6 @@ def stream(ws):
partial_response = b'' partial_response = b''
# TODO: handle when the backend is offline
# Traceback (most recent call last):
# File "/srv/server/local-llm-server/llm_server/routes/v1/generate_stream.py", line 91, in stream
# for chunk in response.iter_content(chunk_size=1):
# ^^^^^^^^^^^^^^^^^^^^^
# AttributeError: 'NoneType' object has no attribute 'iter_content'
for chunk in response.iter_content(chunk_size=1): for chunk in response.iter_content(chunk_size=1):
partial_response += chunk partial_response += chunk
if partial_response.endswith(b'\x00'): if partial_response.endswith(b'\x00'):
@ -116,6 +136,7 @@ def stream(ws):
'text': new 'text': new
})) }))
except: except:
# The client closed the stream.
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
@ -128,41 +149,32 @@ def stream(ws):
if not chunk: if not chunk:
break break
if response:
response.close() response.close()
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
def background_task_success(): except:
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
# TODO: use async/await instead of threads ws.send(json.dumps({
thread = threading.Thread(target=background_task_success) 'event': 'text_stream',
thread.start() 'message_num': message_num,
thread.join() 'text': generated_text
except: }))
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8') log_in_bg(generated_text, is_error=True, status_code=response_status_code)
traceback.print_exc() finally:
ws.send(json.dumps({ # The worker incremented it, we'll decrement it.
'event': 'text_stream', decrement_ip_count(handler.client_ip, 'processing_ips')
'message_num': message_num, redis.decr('active_gen_workers')
'text': generated_text
}))
def background_task_exception():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
try: try:
ws.send(json.dumps({ ws.send(json.dumps({
'event': 'stream_end', 'event': 'stream_end',
'message_num': message_num 'message_num': message_num
})) }))
except: except:
# The client closed the stream.
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))

View File

View File

@ -0,0 +1,58 @@
import json
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers')
if not request_json_body:
# This was a dummy request from the websocket handler.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
continue
try:
start_time = time.time()
success, response, error_msg = generator(request_json_body)
end_time = time.time()
elapsed_time = end_time - start_time
# redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
redis.decr('active_gen_workers')
def start_workers(num_workers: int):
i = 0
for _ in range(num_workers):
t = threading.Thread(target=worker)
t.daemon = True
t.start()
i += 1
print(f'Started {i} inference workers.')
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.5:
print(f'Worker was delayed {e - s} seconds.')

View File

@ -0,0 +1,29 @@
import logging
import threading
import time
from llm_server.routes.cache import redis
def console_printer():
logger = logging.getLogger('console_printer')
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
while True:
queued_ip_count = redis.get_dict('queued_ip_count')
queued_ip_count = sum([v for k, v in queued_ip_count.items()])
processing_ips = redis.get_dict('processing_ips')
processing_count = sum([v for k, v in processing_ips.items()])
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
time.sleep(10)
def start_console_printer():
t = threading.Thread(target=console_printer)
t.daemon = True
t.start()

View File

@ -1,4 +1,4 @@
from redis import Redis from llm_server.workers.printer import start_console_printer
try: try:
import gevent.monkey import gevent.monkey
@ -16,6 +16,7 @@ from threading import Thread
import openai import openai
import simplejson as json import simplejson as json
from flask import Flask, jsonify, render_template, request from flask import Flask, jsonify, render_template, request
from redis import Redis
import llm_server import llm_server
from llm_server.database.conn import database from llm_server.database.conn import database
@ -26,6 +27,7 @@ from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio from llm_server.stream import init_socketio
from llm_server.workers.blocking import start_workers
# TODO: have the workers handle streaming too # TODO: have the workers handle streaming too
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@ -36,6 +38,7 @@ from llm_server.stream import init_socketio
# TODO: implement RRD backend loadbalancer option # TODO: implement RRD backend loadbalancer option
# TODO: have VLLM reject a request if it already has n == concurrent_gens running # TODO: have VLLM reject a request if it already has n == concurrent_gens running
# TODO: add a way to cancel VLLM gens. Maybe use websockets? # TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: use coloredlogs
# Lower priority # Lower priority
# TODO: the processing stat showed -1 and I had to restart the server # TODO: the processing stat showed -1 and I had to restart the server
@ -65,7 +68,6 @@ from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis from llm_server.llm import redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
@ -166,6 +168,8 @@ def pre_fork(server):
redis.set_dict('processing_ips', {}) redis.set_dict('processing_ips', {})
redis.set_dict('queued_ip_count', {}) redis.set_dict('queued_ip_count', {})
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15) queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'): for key in queue_redis.scan_iter('*'):
queue_redis.delete(key) queue_redis.delete(key)
@ -181,8 +185,7 @@ def pre_fork(server):
# Start background processes # Start background processes
start_workers(opts.concurrent_gens) start_workers(opts.concurrent_gens)
print(f'Started {opts.concurrent_gens} inference workers.') start_console_printer()
start_moderation_workers(opts.openai_moderation_workers) start_moderation_workers(opts.openai_moderation_workers)
MainBackgroundThread().start() MainBackgroundThread().start()
SemaphoreCheckerThread().start() SemaphoreCheckerThread().start()