rewrite redis usage

This commit is contained in:
Cyberes 2023-09-28 03:44:30 -06:00
parent a4a1d6cce6
commit 59f2aac8ad
8 changed files with 84 additions and 54 deletions

View File

@ -78,9 +78,6 @@ def load_config(config_path, script_path):
if config['load_num_prompts']: if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts')) redis.set('proompts', get_number_of_rows('prompts'))
redis.set_dict('recent_prompters', {})
redis.set_dict('processing_ips', {})
redis.set_dict('queued_ip_count', {})
redis.set('backend_mode', opts.mode) redis.set('backend_mode', opts.mode)
return success, config, msg return success, config, msg

View File

@ -1,12 +1,12 @@
import sys import sys
import traceback import traceback
from typing import Union from typing import Callable, List, Mapping, Union
import redis as redis_pkg import redis as redis_pkg
import simplejson as json import simplejson as json
from flask_caching import Cache from flask_caching import Cache
from redis import Redis from redis import Redis
from redis.typing import ExpiryT, FieldT from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@ -86,6 +86,52 @@ class RedisWrapper:
def llen(self, name: str): def llen(self, name: str):
return self.redis.llen(self._key(name)) return self.redis.llen(self._key(name))
def zrangebyscore(
    self,
    name: KeyT,
    min: ZScoreBoundT,
    max: ZScoreBoundT,
    start: Union[int, None] = None,
    num: Union[int, None] = None,
    withscores: bool = False,
    score_cast_func: Union[type, Callable] = float,
):
    """Return members of the sorted set ``name`` with scores between ``min`` and ``max``.

    Thin delegation to ``redis.Redis.zrangebyscore``; the only added behavior is
    namespacing the key through ``self._key(name)`` (wrapper key-prefixing).
    ``start``/``num`` page the result, ``withscores`` includes scores in the
    return value, and ``score_cast_func`` converts each returned score.
    """
    return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)
def zremrangebyscore(
    self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
):
    """Remove members of the sorted set ``name`` with scores between ``min`` and ``max``.

    Delegates to ``redis.Redis.zremrangebyscore`` after prefixing the key via
    ``self._key``; returns the count reported by the underlying client.
    """
    return self.redis.zremrangebyscore(self._key(name), min, max)
def hincrby(
    self, name: str, key: str, amount: int = 1
):
    """Increment hash field ``key`` in hash ``name`` by ``amount`` (default 1).

    Delegates to ``redis.Redis.hincrby`` with the prefixed key; returns the
    new integer value of the field, per the underlying client.
    """
    return self.redis.hincrby(self._key(name), key, amount)
def hdel(self, name: str, *keys: List):
return self.redis.hdel(self._key(name), *keys)
def hget(
    self, name: str, key: str
):
    """Return the value of field ``key`` in hash ``name``.

    Delegates to ``redis.Redis.hget`` with the prefixed hash key; the raw
    client return value (bytes or None, depending on client config) is
    passed through unchanged.
    """
    return self.redis.hget(self._key(name), key)
def zadd(
    self,
    name: KeyT,
    mapping: Mapping[AnyKeyT, EncodableT],
    nx: bool = False,
    xx: bool = False,
    ch: bool = False,
    incr: bool = False,
    gt: bool = False,
    lt: bool = False,
):
    """Add ``mapping`` (member -> score) to the sorted set ``name``.

    Delegates to ``redis.Redis.zadd`` with the prefixed key, forwarding the
    standard ZADD option flags (``nx``/``xx``/``ch``/``incr``/``gt``/``lt``)
    positionally in the client's parameter order.
    """
    return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def hkeys(self, name: str):
    """Return all field names of hash ``name`` (prefixed via ``self._key``)."""
    return self.redis.hkeys(self._key(name))
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex) return self.set(key, json.dumps(dict_value), ex=ex)

View File

@ -9,21 +9,14 @@ from llm_server import opts
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
def increment_ip_count(client_ip: int, redis_key): def increment_ip_count(client_ip: str, redis_key):
ip_count = redis.get_dict(redis_key) redis.hincrby(redis_key, client_ip, 1)
ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
redis.set_dict(redis_key, ip_count)
return ip_count
def decrement_ip_count(client_ip: int, redis_key): def decrement_ip_count(client_ip: str, redis_key):
ip_count = redis.get_dict(redis_key) new_count = redis.hincrby(redis_key, client_ip, -1)
if client_ip in ip_count.keys(): if new_count <= 0:
ip_count[client_ip] -= 1 redis.hdel(redis_key, client_ip)
if ip_count[client_ip] == 0:
del ip_count[client_ip] # Remove the IP from the dictionary if count is 0
redis.set_dict(redis_key, ip_count)
return ip_count
class RedisPriorityQueue: class RedisPriorityQueue:
@ -57,21 +50,22 @@ class RedisPriorityQueue:
return item return item
time.sleep(0.1) # wait for something to be added to the queue time.sleep(0.1) # wait for something to be added to the queue
def increment_ip_count(self, ip, key): def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(key, ip, 1) self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, ip, key): def decrement_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(key, ip, -1) new_count = self.redis.hincrby(redis_key, client_ip, -1)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
def __len__(self): def __len__(self):
return self.redis.zcard('queue') return self.redis.zcard('queue')
def get_ip_count(self, client_ip: str): def get_queued_ip_count(self, client_ip: str):
x = self.redis.hget('queued_ip_count', client_ip) q = self.redis.hget('queued_ip_count', client_ip)
if x: if not q:
return x.decode('utf-8') return 0
else: return 0
return x
class DataEvent: class DataEvent:

View File

@ -38,9 +38,7 @@ class RequestHandler:
self.backend = get_backend() self.backend = get_backend()
self.parameters = None self.parameters = None
self.used = False self.used = False
recent_prompters = redis.get_dict('recent_prompters') redis.zadd('recent_prompters', {self.client_ip: time.time()})
recent_prompters[self.client_ip] = (time.time(), self.token)
redis.set_dict('recent_prompters', recent_prompters)
def get_auth_token(self): def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'): if self.request_json_body.get('X-API-KEY'):
@ -191,11 +189,16 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool: def is_client_ratelimited(self) -> bool:
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0: x = redis.hget('processing_ips', self.client_ip)
if x:
processing_ip = int(x)
else:
processing_ip = 0
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False return False
else: else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count} queued + processing.') print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
return True return True
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:

View File

@ -76,7 +76,7 @@ def generate_stats(regen: bool = False):
netdata_stats = {} netdata_stats = {}
base_client_api = redis.get('base_client_api', str) base_client_api = redis.get('base_client_api', str)
proompters_5_min = redis.get('proompters_5_min', int) proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = { output = {
'stats': { 'stats': {

View File

@ -150,9 +150,6 @@ def stream(ws):
if not chunk: if not chunk:
break break
if response:
response.close()
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code) log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)

View File

@ -2,6 +2,7 @@ import logging
import time import time
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
from llm_server.routes.queue import priority_queue
logger = logging.getLogger('console_printer') logger = logging.getLogger('console_printer')
if not logger.handlers: if not logger.handlers:
@ -16,7 +17,9 @@ if not logger.handlers:
def console_printer(): def console_printer():
time.sleep(3) time.sleep(3)
while True: while True:
queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()]) processing = redis.hkeys('processing_ips')
processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()]) processing_count = 0
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}') for ip in processing:
time.sleep(15) processing_count += int(redis.hget('processing_ips', ip))
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
time.sleep(1)

View File

@ -4,16 +4,6 @@ from llm_server.routes.cache import redis
def recent_prompters_thread(): def recent_prompters_thread():
current_time = time.time() five_minutes_ago = time.time() - 5 * 60
recent_prompters = redis.get_dict('recent_prompters') redis.zremrangebyscore('recent_prompters', '-inf', five_minutes_ago)
new_recent_prompters = {}
for ip, (timestamp, token) in recent_prompters.items():
if token and token.startswith('SYSTEM__'):
continue
if current_time - timestamp <= 300:
new_recent_prompters[ip] = timestamp, token
redis.set_dict('recent_prompters', new_recent_prompters)
redis.set('proompters_5_min', len(new_recent_prompters))
time.sleep(1) time.sleep(1)