rewrite redis usage
This commit is contained in:
parent
a4a1d6cce6
commit
59f2aac8ad
|
@ -78,9 +78,6 @@ def load_config(config_path, script_path):
|
||||||
if config['load_num_prompts']:
|
if config['load_num_prompts']:
|
||||||
redis.set('proompts', get_number_of_rows('prompts'))
|
redis.set('proompts', get_number_of_rows('prompts'))
|
||||||
|
|
||||||
redis.set_dict('recent_prompters', {})
|
|
||||||
redis.set_dict('processing_ips', {})
|
|
||||||
redis.set_dict('queued_ip_count', {})
|
|
||||||
redis.set('backend_mode', opts.mode)
|
redis.set('backend_mode', opts.mode)
|
||||||
|
|
||||||
return success, config, msg
|
return success, config, msg
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union
|
from typing import Callable, List, Mapping, Union
|
||||||
|
|
||||||
import redis as redis_pkg
|
import redis as redis_pkg
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from flask_caching import Cache
|
from flask_caching import Cache
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from redis.typing import ExpiryT, FieldT
|
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
|
||||||
|
|
||||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||||
|
|
||||||
|
@ -86,6 +86,52 @@ class RedisWrapper:
|
||||||
def llen(self, name: str):
    """Return ``LLEN`` for the list stored under the key derived from ``name``.

    ``name`` is first mapped through ``self._key`` and the result is passed
    to the underlying client's ``llen``.
    """
    list_key = self._key(name)
    return self.redis.llen(list_key)
|
||||||
|
|
||||||
|
def zrangebyscore(
    self,
    name: KeyT,
    min: ZScoreBoundT,
    max: ZScoreBoundT,
    start: Union[int, None] = None,
    num: Union[int, None] = None,
    withscores: bool = False,
    score_cast_func: Union[type, Callable] = float,
):
    """Forward ``ZRANGEBYSCORE`` to the underlying client for the key derived from ``name``.

    All arguments other than ``name`` are handed to the client unchanged;
    ``name`` is mapped through ``self._key`` first. The return value is
    whatever the client's ``zrangebyscore`` returns.
    """
    # ``min``/``max`` shadow builtins, but the names mirror the client's own
    # signature and are part of this wrapper's public interface.
    return self.redis.zrangebyscore(
        self._key(name),
        min,
        max,
        start=start,
        num=num,
        withscores=withscores,
        score_cast_func=score_cast_func,
    )
|
||||||
|
|
||||||
|
def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT):
    """Forward ``ZREMRANGEBYSCORE`` to the underlying client for the key derived from ``name``.

    ``min`` and ``max`` are passed through untouched; the client's return
    value is returned as-is.
    """
    sorted_set_key = self._key(name)
    return self.redis.zremrangebyscore(sorted_set_key, min, max)
|
||||||
|
|
||||||
|
def hincrby(self, name: str, key: str, amount: int = 1):
    """Forward ``HINCRBY`` for field ``key`` of the hash derived from ``name``.

    Returns whatever the underlying client's ``hincrby`` returns.
    """
    hash_key = self._key(name)
    return self.redis.hincrby(hash_key, key, amount)
|
||||||
|
|
||||||
|
def hdel(self, name: str, *keys: str):
    """Delete one or more fields from the hash derived from ``name``.

    Args:
        name: Logical hash name; mapped through ``self._key`` before use.
        *keys: Field names to remove. Each one is a single field name
            (the annotation applies per argument, not to the whole tuple).

    Returns:
        The underlying client's ``hdel`` result, or ``0`` when no field
        names were supplied.
    """
    # HDEL needs at least one field; issuing it with none is rejected by the
    # server, so an empty call is answered locally as "nothing removed".
    if not keys:
        return 0
    return self.redis.hdel(self._key(name), *keys)
|
||||||
|
|
||||||
|
def hget(self, name: str, key: str):
    """Forward ``HGET`` for field ``key`` of the hash derived from ``name``.

    The underlying client's result is returned unchanged (no decoding or
    type conversion is done here).
    """
    hash_key = self._key(name)
    return self.redis.hget(hash_key, key)
|
||||||
|
|
||||||
|
def zadd(
    self,
    name: KeyT,
    mapping: Mapping[AnyKeyT, EncodableT],
    nx: bool = False,
    xx: bool = False,
    ch: bool = False,
    incr: bool = False,
    gt: bool = False,
    lt: bool = False,
):
    """Forward ``ZADD`` to the underlying client for the key derived from ``name``.

    ``mapping`` and every flag are handed to the client's ``zadd``
    unchanged; only ``name`` is rewritten via ``self._key``. The client's
    return value is returned as-is.
    """
    return self.redis.zadd(
        self._key(name),
        mapping,
        nx=nx,
        xx=xx,
        ch=ch,
        incr=incr,
        gt=gt,
        lt=lt,
    )
|
||||||
|
|
||||||
|
def hkeys(self, name: str):
    """Forward ``HKEYS`` for the hash derived from ``name`` and return the client's result."""
    hash_key = self._key(name)
    return self.redis.hkeys(hash_key)
|
||||||
|
|
||||||
def set_dict(self, key: str, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
    """Store a JSON-serialisable value under ``key``.

    Args:
        key: Name handed to ``self.set`` (a key name, not a container).
        dict_value: The list or dict to serialise with ``json.dumps``.
        ex: Optional expiry, forwarded to ``self.set`` as ``ex``.

    Returns:
        Whatever ``self.set`` returns.
    """
    serialized = json.dumps(dict_value)
    return self.set(key, serialized, ex=ex)
|
||||||
|
|
||||||
|
|
|
@ -9,21 +9,14 @@ from llm_server import opts
|
||||||
from llm_server.routes.cache import redis
|
from llm_server.routes.cache import redis
|
||||||
|
|
||||||
|
|
||||||
def increment_ip_count(client_ip: str, redis_key):
    """Add one to the counter kept for ``client_ip`` in the hash ``redis_key``.

    The increment is delegated to ``redis.hincrby``; nothing is returned.
    """
    redis.hincrby(redis_key, client_ip, amount=1)
|
|
||||||
|
|
||||||
|
|
||||||
def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from the counter for ``client_ip`` in the hash ``redis_key``.

    When the decremented value is zero or negative the field is removed
    from the hash. Nothing is returned.
    """
    remaining = redis.hincrby(redis_key, client_ip, -1)
    if remaining > 0:
        return
    # NOTE(review): the decrement and the delete are two separate calls.
    redis.hdel(redis_key, client_ip)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisPriorityQueue:
|
class RedisPriorityQueue:
|
||||||
|
@ -57,21 +50,22 @@ class RedisPriorityQueue:
|
||||||
return item
|
return item
|
||||||
time.sleep(0.1) # wait for something to be added to the queue
|
time.sleep(0.1) # wait for something to be added to the queue
|
||||||
|
|
||||||
def increment_ip_count(self, client_ip: str, redis_key):
    """Add one to the counter kept for ``client_ip`` in the hash ``redis_key``.

    Uses this queue's own ``self.redis`` handle; nothing is returned.
    """
    self.redis.hincrby(redis_key, client_ip, amount=1)
|
||||||
|
|
||||||
def decrement_ip_count(self, client_ip: str, redis_key):
    """Subtract one from the counter for ``client_ip`` in the hash ``redis_key``.

    When the decremented value is zero or negative the field is removed
    from the hash. Nothing is returned.
    """
    remaining = self.redis.hincrby(redis_key, client_ip, -1)
    if remaining > 0:
        return
    # NOTE(review): the decrement and the delete are two separate calls.
    self.redis.hdel(redis_key, client_ip)
|
||||||
|
|
||||||
def __len__(self):
    """Return the ``ZCARD`` of the ``'queue'`` sorted set as reported by ``self.redis``."""
    queue_size = self.redis.zcard('queue')
    return queue_size
|
||||||
|
|
||||||
def get_queued_ip_count(self, client_ip: str):
    """Return how many requests from ``client_ip`` are recorded as queued.

    Reads the ``client_ip`` field of the ``'queued_ip_count'`` hash.

    Args:
        client_ip: Field name to look up in the hash.

    Returns:
        The stored counter as an ``int``, or ``0`` when the field is absent
        or empty.
    """
    q = self.redis.hget('queued_ip_count', client_ip)
    if not q:
        return 0
    # Previously this path also returned 0, so the stored count was never
    # reported. int() accepts both the bytes and str forms the client may
    # hand back.
    return int(q)
|
|
||||||
|
|
||||||
|
|
||||||
class DataEvent:
|
class DataEvent:
|
||||||
|
|
|
@ -38,9 +38,7 @@ class RequestHandler:
|
||||||
self.backend = get_backend()
|
self.backend = get_backend()
|
||||||
self.parameters = None
|
self.parameters = None
|
||||||
self.used = False
|
self.used = False
|
||||||
recent_prompters = redis.get_dict('recent_prompters')
|
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||||
recent_prompters[self.client_ip] = (time.time(), self.token)
|
|
||||||
redis.set_dict('recent_prompters', recent_prompters)
|
|
||||||
|
|
||||||
def get_auth_token(self):
|
def get_auth_token(self):
|
||||||
if self.request_json_body.get('X-API-KEY'):
|
if self.request_json_body.get('X-API-KEY'):
|
||||||
|
@ -191,11 +189,16 @@ class RequestHandler:
|
||||||
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||||
|
|
||||||
def is_client_ratelimited(self) -> bool:
    """Decide whether this client's request should be rejected.

    Adds the client's queued count (from ``priority_queue``) to its
    ``'processing_ips'`` counter and compares the sum with
    ``self.token_simultaneous_ip``. Tokens with ``token_priority == 0`` are
    never limited.

    Returns:
        ``False`` when the request may proceed, ``True`` (after printing a
        rejection message) when it is rate limited.
    """
    queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))

    # A missing or empty hash field counts as zero in-flight requests.
    raw_processing = redis.hget('processing_ips', self.client_ip)
    processing_ip = int(raw_processing) if raw_processing else 0

    if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
        return False

    print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
    return True
|
||||||
|
|
||||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||||
|
|
|
@ -76,7 +76,7 @@ def generate_stats(regen: bool = False):
|
||||||
netdata_stats = {}
|
netdata_stats = {}
|
||||||
|
|
||||||
base_client_api = redis.get('base_client_api', str)
|
base_client_api = redis.get('base_client_api', str)
|
||||||
proompters_5_min = redis.get('proompters_5_min', int)
|
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
|
||||||
|
|
||||||
output = {
|
output = {
|
||||||
'stats': {
|
'stats': {
|
||||||
|
|
|
@ -150,9 +150,6 @@ def stream(ws):
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
|
|
||||||
if response:
|
|
||||||
response.close()
|
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
|
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from llm_server.routes.cache import redis
|
from llm_server.routes.cache import redis
|
||||||
|
from llm_server.routes.queue import priority_queue
|
||||||
|
|
||||||
logger = logging.getLogger('console_printer')
|
logger = logging.getLogger('console_printer')
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
|
@ -16,7 +17,9 @@ if not logger.handlers:
|
||||||
def console_printer():
    """Log the number of processing and queued requests once per second, forever.

    After an initial 3-second delay, each iteration sums every field of the
    ``'processing_ips'`` hash, logs that total together with
    ``len(priority_queue)``, then sleeps for one second. The function never
    returns.
    """
    time.sleep(3)
    while True:
        processing_count = 0
        for ip in redis.hkeys('processing_ips'):
            count = redis.hget('processing_ips', ip)
            # The field can disappear between HKEYS and HGET (counters that
            # reach zero are deleted elsewhere); a missing field must not
            # crash the loop, so it is counted as zero.
            if count is not None:
                processing_count += int(count)
        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
        time.sleep(1)
|
||||||
|
|
|
@ -4,16 +4,6 @@ from llm_server.routes.cache import redis
|
||||||
|
|
||||||
|
|
||||||
def recent_prompters_thread():
    """Trim old entries from the ``'recent_prompters'`` sorted set, then pause.

    Removes every member whose score is at or below "now minus five
    minutes" and sleeps for one second. A single call does one pass.
    """
    cutoff = time.time() - 5 * 60
    redis.zremrangebyscore('recent_prompters', min='-inf', max=cutoff)
    time.sleep(1)
|
||||||
|
|
Reference in New Issue