limit amount of simultaneous requests an IP can make
This commit is contained in:
parent
1a4cb5f786
commit
c16d70a24d
|
@ -15,6 +15,11 @@ concurrent_gens: 1
|
|||
# This number is shown to clients and on the home page. (may be important later)
|
||||
token_limit: 7777
|
||||
|
||||
# How many requests a single IP is allowed to put in the queue.
|
||||
# If an IP tries to put more than this their request will be rejected
|
||||
# until the other(s) are completed.
|
||||
ip_in_queue_max: 1
|
||||
|
||||
llm_middleware_name: Local LLM Proxy
|
||||
|
||||
## Optional
|
||||
|
|
|
@ -13,6 +13,7 @@ config_default_vars = {
|
|||
'average_generation_time_mode': 'database',
|
||||
'info_html': None,
|
||||
'show_total_output_tokens': True,
|
||||
'ip_in_queue_max': 3,
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -19,3 +19,4 @@ show_uptime = True
|
|||
average_generation_time_mode = 'database'
|
||||
show_total_output_tokens = True
|
||||
netdata_root = None
|
||||
ip_in_queue_max = 3
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from flask_caching import Cache
|
||||
from redis import Redis
|
||||
from redis.typing import FieldT
|
||||
|
||||
cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
|
||||
|
||||
|
@ -27,6 +28,15 @@ class RedisWrapper:
|
|||
def decr(self, key, amount=1):
|
||||
return self.redis.decr(f"{self.prefix}:{key}", amount)
|
||||
|
||||
def sadd(self, key: str, *values: FieldT):
|
||||
return self.redis.sadd(f"{self.prefix}:{key}", *values)
|
||||
|
||||
def srem(self, key: str, *values: FieldT):
|
||||
return self.redis.srem(f"{self.prefix}:{key}", *values)
|
||||
|
||||
def sismember(self, key: str, value: str):
|
||||
return self.redis.sismember(f"{self.prefix}:{key}", value)
|
||||
|
||||
def flush(self):
|
||||
flushed = []
|
||||
for key in self.redis.scan_iter(f'{self.prefix}:*'):
|
||||
|
|
|
@ -2,22 +2,32 @@ import heapq
|
|||
import threading
|
||||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
|
||||
|
||||
processing_ips = set()
|
||||
processing_ips_lock = threading.Lock()
|
||||
|
||||
|
||||
class PriorityQueue:
|
||||
def __init__(self):
|
||||
self._queue = []
|
||||
self._index = 0
|
||||
self._cv = threading.Condition()
|
||||
self._ip_count = {}
|
||||
|
||||
def put(self, item, priority):
|
||||
event = DataEvent()
|
||||
with self._cv:
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max:
|
||||
return None # reject the request
|
||||
heapq.heappush(self._queue, (-priority, self._index, item, event))
|
||||
self._index += 1
|
||||
# Increment the count for this IP
|
||||
self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
|
||||
self._cv.notify()
|
||||
return event
|
||||
|
||||
|
@ -25,7 +35,12 @@ class PriorityQueue:
|
|||
with self._cv:
|
||||
while len(self._queue) == 0:
|
||||
self._cv.wait()
|
||||
return heapq.heappop(self._queue)
|
||||
_, _, item, event = heapq.heappop(self._queue)
|
||||
# Decrement the count for this IP
|
||||
self._ip_count[item[1]] -= 1
|
||||
if self._ip_count[item[1]] == 0:
|
||||
del self._ip_count[item[1]] # Remove the IP from the dictionary if count is 0
|
||||
return item, event
|
||||
|
||||
def __len__(self):
|
||||
return len(self._queue)
|
||||
|
@ -41,10 +56,11 @@ class DataEvent(threading.Event):
|
|||
|
||||
|
||||
def worker():
|
||||
global active_gen_workers
|
||||
global processing_ips_lock
|
||||
while True:
|
||||
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
|
||||
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
|
||||
|
||||
redis.sadd('processing_ips', client_ip)
|
||||
redis.incr('active_gen_workers')
|
||||
|
||||
start_time = time.time()
|
||||
|
@ -58,6 +74,7 @@ def worker():
|
|||
event.data = (success, response, error_msg)
|
||||
event.set()
|
||||
|
||||
redis.srem('processing_ips', client_ip)
|
||||
redis.decr('active_gen_workers')
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,29 @@ def generate():
|
|||
else:
|
||||
print(f'Token {token} was given priority {priority}.')
|
||||
|
||||
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
|
||||
if not redis.sismember('processing_ips', client_ip):
|
||||
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
|
||||
else:
|
||||
event = None
|
||||
if not event:
|
||||
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
|
||||
if opts.mode == 'oobabooga':
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
response_json_body = {
|
||||
'results': [
|
||||
{
|
||||
'text': backend_response,
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise Exception
|
||||
return jsonify({
|
||||
# 'code': 429,
|
||||
# 'error': f'no more than {opts.ip_in_queue_max} simultaneous requests per IP',
|
||||
**response_json_body
|
||||
}), 200
|
||||
|
||||
event.wait()
|
||||
success, response, error_msg = event.data
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ opts.show_uptime = config['show_uptime']
|
|||
opts.backend_url = config['backend_url'].strip('/')
|
||||
opts.show_total_output_tokens = config['show_total_output_tokens']
|
||||
opts.netdata_root = config['netdata_root']
|
||||
opts.ip_in_queue_max = config['ip_in_queue_max']
|
||||
|
||||
opts.verify_ssl = config['verify_ssl']
|
||||
if not opts.verify_ssl:
|
||||
|
|
Reference in New Issue