rewrite redis usage
parent a4a1d6cce6
commit 59f2aac8ad
@@ -78,9 +78,6 @@ def load_config(config_path, script_path):
     if config['load_num_prompts']:
         redis.set('proompts', get_number_of_rows('prompts'))
 
-    redis.set_dict('recent_prompters', {})
-    redis.set_dict('processing_ips', {})
-    redis.set_dict('queued_ip_count', {})
     redis.set('backend_mode', opts.mode)
 
     return success, config, msg
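These deletions work because the new hash and sorted-set keys need no up-front reset: Redis creates a key on the first HINCRBY/ZADD and drops it again when its last field or member is removed. A minimal sketch of that behavior with a plain redis-py client (key name taken from the diff; the connection defaults are assumptions):

```python
import redis

r = redis.Redis()  # assumes Redis on localhost:6379

r.delete('queued_ip_count')                      # start with no key at all
r.hincrby('queued_ip_count', '203.0.113.7', 1)   # first write creates the hash
r.hdel('queued_ip_count', '203.0.113.7')         # removing the last field...
print(r.exists('queued_ip_count'))               # ...deletes the key: prints 0
```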
@@ -1,12 +1,12 @@
 import sys
 import traceback
-from typing import Union
+from typing import Callable, List, Mapping, Union
 
 import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import ExpiryT, FieldT
+from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
 
 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@@ -86,6 +86,52 @@ class RedisWrapper:
     def llen(self, name: str):
         return self.redis.llen(self._key(name))
 
+    def zrangebyscore(
+            self,
+            name: KeyT,
+            min: ZScoreBoundT,
+            max: ZScoreBoundT,
+            start: Union[int, None] = None,
+            num: Union[int, None] = None,
+            withscores: bool = False,
+            score_cast_func: Union[type, Callable] = float,
+    ):
+        return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)
+
+    def zremrangebyscore(
+            self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
+    ):
+        return self.redis.zremrangebyscore(self._key(name), min, max)
+
+    def hincrby(
+            self, name: str, key: str, amount: int = 1
+    ):
+        return self.redis.hincrby(self._key(name), key, amount)
+
+    def hdel(self, name: str, *keys: List):
+        return self.redis.hdel(self._key(name), *keys)
+
+    def hget(
+            self, name: str, key: str
+    ):
+        return self.redis.hget(self._key(name), key)
+
+    def zadd(
+            self,
+            name: KeyT,
+            mapping: Mapping[AnyKeyT, EncodableT],
+            nx: bool = False,
+            xx: bool = False,
+            ch: bool = False,
+            incr: bool = False,
+            gt: bool = False,
+            lt: bool = False,
+    ):
+        return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
+
+    def hkeys(self, name: str):
+        return self.redis.hkeys(self._key(name))
+
     def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
         return self.set(key, json.dumps(dict_value), ex=ex)
 
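All of the new wrapper methods follow the same pattern: pass straight through to redis-py, but run the key through `self._key(name)` first so every key is namespaced. A standalone sketch of that pattern (the class name, prefix, and connection defaults here are illustrative, not from the diff):

```python
import redis


class PrefixedRedis:
    """Pass-through wrapper that namespaces every key (illustrative sketch)."""

    def __init__(self, prefix: str = 'local_llm'):
        self.redis = redis.Redis()  # assumes Redis on localhost:6379
        self.prefix = prefix

    def _key(self, name: str) -> str:
        # Namespacing keeps this app's keys from colliding with anything
        # else stored in the same Redis database.
        return f'{self.prefix}:{name}'

    def hincrby(self, name: str, key: str, amount: int = 1) -> int:
        return self.redis.hincrby(self._key(name), key, amount)
```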
@@ -9,21 +9,14 @@ from llm_server import opts
 from llm_server.routes.cache import redis
 
 
-def increment_ip_count(client_ip: int, redis_key):
-    ip_count = redis.get_dict(redis_key)
-    ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
-    redis.set_dict(redis_key, ip_count)
-    return ip_count
+def increment_ip_count(client_ip: str, redis_key):
+    redis.hincrby(redis_key, client_ip, 1)
 
 
-def decrement_ip_count(client_ip: int, redis_key):
-    ip_count = redis.get_dict(redis_key)
-    if client_ip in ip_count.keys():
-        ip_count[client_ip] -= 1
-        if ip_count[client_ip] == 0:
-            del ip_count[client_ip]  # Remove the IP from the dictionary if count is 0
-        redis.set_dict(redis_key, ip_count)
-    return ip_count
+def decrement_ip_count(client_ip: str, redis_key):
+    new_count = redis.hincrby(redis_key, client_ip, -1)
+    if new_count <= 0:
+        redis.hdel(redis_key, client_ip)
 
 
 class RedisPriorityQueue:
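The point of this change is atomicity: the old helpers read the whole dict out of Redis as JSON, mutated it in Python, and wrote it back, so two concurrent requests could read the same count and silently lose an increment. HINCRBY performs the increment inside the server, where concurrent calls serialize. A sketch of the new pattern with a plain client (key and IP follow the diff's usage; the connection is assumed):

```python
import redis

r = redis.Redis()  # assumes a local Redis

# Increment: atomic server-side, returns the post-increment value.
count = r.hincrby('queued_ip_count', '203.0.113.7', 1)

# Decrement, pruning the field so the hash doesn't accumulate dead IPs.
count = r.hincrby('queued_ip_count', '203.0.113.7', -1)
if count <= 0:
    r.hdel('queued_ip_count', '203.0.113.7')
```

Strictly speaking the decrement-then-delete pair is still two round trips, so another client can slip in between them; for a best-effort counter that is usually acceptable, and a small Lua script could close the gap if it ever matters.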
@@ -57,21 +50,22 @@ class RedisPriorityQueue:
                 return item
             time.sleep(0.1)  # wait for something to be added to the queue
 
-    def increment_ip_count(self, ip, key):
-        self.redis.hincrby(key, ip, 1)
+    def increment_ip_count(self, client_ip: str, redis_key):
+        self.redis.hincrby(redis_key, client_ip, 1)
 
-    def decrement_ip_count(self, ip, key):
-        self.redis.hincrby(key, ip, -1)
+    def decrement_ip_count(self, client_ip: str, redis_key):
+        new_count = self.redis.hincrby(redis_key, client_ip, -1)
+        if new_count <= 0:
+            self.redis.hdel(redis_key, client_ip)
 
     def __len__(self):
         return self.redis.zcard('queue')
 
-    def get_ip_count(self, client_ip: str):
-        x = self.redis.hget('queued_ip_count', client_ip)
-        if x:
-            return x.decode('utf-8')
-        else:
-            return x
+    def get_queued_ip_count(self, client_ip: str):
+        q = self.redis.hget('queued_ip_count', client_ip)
+        if not q:
+            return 0
+        return int(q)
 
 
 class DataEvent:
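One subtlety in `get_queued_ip_count`: a default redis-py client returns `hget` results as bytes (or None for a missing field), so the value has to be converted before any arithmetic; conveniently, `int()` parses bytes directly. A quick illustration (key and IP are examples):

```python
import redis

r = redis.Redis()  # default client: responses come back as bytes

r.hset('queued_ip_count', '198.51.100.9', 2)
raw = r.hget('queued_ip_count', '198.51.100.9')  # b'2', or None if missing
count = int(raw) if raw is not None else 0       # int() accepts bytes
```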
@@ -38,9 +38,7 @@ class RequestHandler:
         self.backend = get_backend()
         self.parameters = None
         self.used = False
-        recent_prompters = redis.get_dict('recent_prompters')
-        recent_prompters[self.client_ip] = (time.time(), self.token)
-        redis.set_dict('recent_prompters', recent_prompters)
+        redis.zadd('recent_prompters', {self.client_ip: time.time()})
 
     def get_auth_token(self):
         if self.request_json_body.get('X-API-KEY'):
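Storing recent prompters as a sorted set, with member = IP and score = last-seen timestamp, turns the old dict bookkeeping into range operations: `zadd` refreshes an IP's timestamp in place, a score-range read lists who was active in the last five minutes, and a score-range delete expires old entries (the cleanup thread further down does exactly that). A lifecycle sketch with a plain client, under the same key layout:

```python
import time

import redis

r = redis.Redis()

# Record activity; re-adding an existing IP just updates its score.
r.zadd('recent_prompters', {'192.0.2.1': time.time()})

# Who was active in the last five minutes?
five_min_ago = time.time() - 5 * 60
active = r.zrangebyscore('recent_prompters', five_min_ago, '+inf')

# Expire everything older than the window.
r.zremrangebyscore('recent_prompters', '-inf', five_min_ago)
```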
@@ -191,11 +189,16 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def is_client_ratelimited(self) -> bool:
-        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
-        if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
+        queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
+        x = redis.hget('processing_ips', self.client_ip)
+        if x:
+            processing_ip = int(x)
+        else:
+            processing_ip = 0
+        if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
             return False
         else:
-            print(f'Rejecting request from {self.client_ip} - {queued_ip_count} queued + processing.')
+            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
             return True
 
     def handle_request(self) -> Tuple[flask.Response, int]:
@@ -76,7 +76,7 @@ def generate_stats(regen: bool = False):
     netdata_stats = {}
 
     base_client_api = redis.get('base_client_api', str)
-    proompters_5_min = redis.get('proompters_5_min', int)
+    proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
 
     output = {
         'stats': {
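`len(redis.zrangebyscore(...))` transfers every member just to count them; `ZCOUNT` returns the cardinality directly. The wrapper in this commit does not expose `zcount`, so this alternative is an assumption, shown with a plain client:

```python
import time

import redis

r = redis.Redis()

# Same number as len(zrangebyscore(...)), without shipping the member list.
proompters_5_min = r.zcount('recent_prompters', time.time() - 5 * 60, '+inf')
```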
@@ -150,9 +150,6 @@ def stream(ws):
             if not chunk:
                 break
 
-        if response:
-            response.close()
-
         end_time = time.time()
         elapsed_time = end_time - start_time
         log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
@@ -2,6 +2,7 @@ import logging
 import time
 
 from llm_server.routes.cache import redis
+from llm_server.routes.queue import priority_queue
 
 logger = logging.getLogger('console_printer')
 if not logger.handlers:
@@ -16,7 +17,9 @@ if not logger.handlers:
 def console_printer():
     time.sleep(3)
     while True:
-        queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
-        processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
-        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
-        time.sleep(15)
+        processing = redis.hkeys('processing_ips')
+        processing_count = 0
+        for ip in processing:
+            processing_count += int(redis.hget('processing_ips', ip))
+        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
+        time.sleep(1)
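The printer loop issues one HGET per processing IP on every iteration. `HVALS` fetches all field values in a single round trip, which would shrink this to one call; a hedged alternative with a plain client (the wrapper does not currently expose `hvals`):

```python
import redis

r = redis.Redis()

# One round trip instead of 1 + len(processing) calls:
processing_count = sum(int(v) for v in r.hvals('processing_ips'))
```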
@@ -4,16 +4,6 @@ from llm_server.routes.cache import redis
 
 
 def recent_prompters_thread():
-    current_time = time.time()
-    recent_prompters = redis.get_dict('recent_prompters')
-    new_recent_prompters = {}
-
-    for ip, (timestamp, token) in recent_prompters.items():
-        if token and token.startswith('SYSTEM__'):
-            continue
-        if current_time - timestamp <= 300:
-            new_recent_prompters[ip] = timestamp, token
-
-    redis.set_dict('recent_prompters', new_recent_prompters)
-    redis.set('proompters_5_min', len(new_recent_prompters))
+    five_minutes_ago = time.time() - 5 * 60
+    redis.zremrangebyscore('recent_prompters', '-inf', five_minutes_ago)
     time.sleep(1)