limit amount of simultaneous requests an IP can make

This commit is contained in:
Cyberes 2023-08-27 23:48:10 -06:00
parent 1a4cb5f786
commit c16d70a24d
7 changed files with 61 additions and 4 deletions

View File

@ -15,6 +15,11 @@ concurrent_gens: 1
# This number is shown to clients and on the home page. (may be important later)
token_limit: 7777
# How many requests a single IP is allowed to put in the queue.
# If an IP tries to put more than this their request will be rejected
# until the other(s) are completed.
ip_in_queue_max: 1
llm_middleware_name: Local LLM Proxy
## Optional

View File

@ -13,6 +13,7 @@ config_default_vars = {
'average_generation_time_mode': 'database',
'info_html': None,
'show_total_output_tokens': True,
'ip_in_queue_max': 3,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -19,3 +19,4 @@ show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
ip_in_queue_max = 3

View File

@ -1,5 +1,6 @@
from flask_caching import Cache
from redis import Redis
from redis.typing import FieldT
cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
@ -27,6 +28,15 @@ class RedisWrapper:
def decr(self, key, amount=1):
return self.redis.decr(f"{self.prefix}:{key}", amount)
def sadd(self, key: str, *values: FieldT):
return self.redis.sadd(f"{self.prefix}:{key}", *values)
def srem(self, key: str, *values: FieldT):
return self.redis.srem(f"{self.prefix}:{key}", *values)
def sismember(self, key: str, value: str):
return self.redis.sismember(f"{self.prefix}:{key}", value)
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):

View File

@ -2,22 +2,32 @@ import heapq
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
processing_ips = set()
processing_ips_lock = threading.Lock()
class PriorityQueue:
def __init__(self):
self._queue = []
self._index = 0
self._cv = threading.Condition()
self._ip_count = {}
def put(self, item, priority):
event = DataEvent()
with self._cv:
# Check if the IP is already in the dictionary and if it has reached the limit
if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max:
return None # reject the request
heapq.heappush(self._queue, (-priority, self._index, item, event))
self._index += 1
# Increment the count for this IP
self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
self._cv.notify()
return event
@ -25,7 +35,12 @@ class PriorityQueue:
with self._cv:
while len(self._queue) == 0:
self._cv.wait()
return heapq.heappop(self._queue)
_, _, item, event = heapq.heappop(self._queue)
# Decrement the count for this IP
self._ip_count[item[1]] -= 1
if self._ip_count[item[1]] == 0:
del self._ip_count[item[1]] # Remove the IP from the dictionary if count is 0
return item, event
def __len__(self):
return len(self._queue)
@ -41,10 +56,11 @@ class DataEvent(threading.Event):
def worker():
global active_gen_workers
global processing_ips_lock
while True:
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
redis.sadd('processing_ips', client_ip)
redis.incr('active_gen_workers')
start_time = time.time()
@ -58,6 +74,7 @@ def worker():
event.data = (success, response, error_msg)
event.set()
redis.srem('processing_ips', client_ip)
redis.decr('active_gen_workers')

View File

@ -51,7 +51,29 @@ def generate():
else:
print(f'Token {token} was given priority {priority}.')
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
if not redis.sismember('processing_ips', client_ip):
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
else:
event = None
if not event:
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
raise Exception
return jsonify({
# 'code': 429,
# 'error': f'no more than {opts.ip_in_queue_max} simultaneous requests per IP',
**response_json_body
}), 200
event.wait()
success, response, error_msg = event.data

View File

@ -52,6 +52,7 @@ opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl: