allow setting simultaneous IP limit per-token, fix token use tracker, fix tokens on streaming

Cyberes 2023-09-25 00:55:20 -06:00
parent d2651756df
commit 6459a1c91b
5 changed files with 39 additions and 27 deletions


@ -31,6 +31,7 @@ def create_db():
UNIQUE (token),
type TEXT NOT NULL,
priority INTEGER DEFAULT 9999,
simultaneous_ip INTEGER DEFAULT NULL,
uses INTEGER DEFAULT 0,
max_uses INTEGER,
expire INTEGER,
@ -39,4 +40,3 @@ def create_db():
''')
conn.commit()
cursor.close()
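
The hunk above only changes the schema that create_db() builds, so a database created before this commit would presumably need the column added by hand. A minimal migration sketch, assuming a MySQL-style backend (as the %s placeholders elsewhere in database.py suggest); the connection details and the example token are hypothetical:

    import pymysql

    # Hypothetical credentials; the project normally goes through its db_pool instead.
    conn = pymysql.connect(host='localhost', user='llm', password='secret', database='llm_server')
    with conn.cursor() as cursor:
        # Add the new nullable column with the same default as create_db().
        cursor.execute("ALTER TABLE token_auth ADD COLUMN simultaneous_ip INTEGER DEFAULT NULL")
        # Example: cap one token at 3 concurrent requests per client IP; NULL leaves it unlimited.
        cursor.execute("UPDATE token_auth SET simultaneous_ip = %s WHERE token = %s", (3, 'example-token'))
    conn.commit()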


@ -30,6 +30,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
# TODO: test and verify this works as expected
response = None
if token:
increment_token_uses(token)
timestamp = int(time.time())
conn = db_pool.connection()
cursor = conn.cursor()
@ -61,21 +64,6 @@ def is_valid_api_key(api_key):
cursor.close()
def increment_uses(api_key):
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute("SELECT token FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()
if row is not None:
cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = %s", (api_key,))
return True
conn.commit()
return False
finally:
cursor.close()
def get_number_of_rows(table_name):
conn = db_pool.connection()
cursor = conn.cursor()
@ -114,7 +102,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
cursor = conn.cursor()
try:
try:
cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,))
cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND token NOT LIKE 'SYSTEM__%%' ORDER BY id DESC", (model_name, backend_name, backend_url,))
results = cursor.fetchall()
except Exception:
traceback.print_exc()
@ -162,3 +150,12 @@ def get_distinct_ips_24h():
return result[0] if result else 0
finally:
cursor.close()
def increment_token_uses(token):
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally:
cursor.close()
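
log_prompt() now bumps the counter through the new increment_token_uses() whenever a request carries a token; the removed increment_uses() returned True before its conn.commit() ever ran, which is the "fix token use tracker" part of the commit message. A standalone sketch of the same counter, assuming a direct pymysql connection rather than the project's db_pool (the explicit commit is an assumption about autocommit behaviour):

    import pymysql

    conn = pymysql.connect(host='localhost', user='llm', password='secret', database='llm_server')

    def increment_token_uses(token: str) -> None:
        cursor = conn.cursor()
        try:
            # uses defaults to 0 in the schema, so no COALESCE is needed here.
            cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
            conn.commit()  # needed on a plain connection; the pooled connection may behave differently
        finally:
            cursor.close()

    increment_token_uses('example-token')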


@ -32,8 +32,8 @@ class RequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.request.headers.get('X-Api-Key')
self.priority = self.get_priority()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.backend = get_backend()
self.parameters = None
self.used = False
@ -41,6 +41,13 @@ class RequestHandler:
recent_prompters[self.client_ip] = time.time()
redis.set_dict('recent_prompters', recent_prompters)
def get_auth_token(self):
websocket_key = self.request_json_body.get('X-API-KEY')
if websocket_key:
return websocket_key
else:
return self.request.headers.get('X-Api-Key')
def get_client_ip(self):
if self.request.headers.get('cf-connecting-ip'):
return self.request.headers.get('cf-connecting-ip')
@ -49,19 +56,23 @@ class RequestHandler:
else:
return self.request.remote_addr
def get_priority(self):
def get_token_ratelimit(self):
priority = DEFAULT_PRIORITY
simultaneous_ip = opts.simultaneous_requests_per_ip
if self.token:
conn = db_pool.connection()
cursor = conn.cursor()
try:
cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,))
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
result = cursor.fetchone()
if result:
return result[0]
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return DEFAULT_PRIORITY
return priority, simultaneous_ip
def get_parameters(self):
if self.request_json_body.get('max_tokens'):
@ -119,7 +130,7 @@ class RequestHandler:
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
else:
event = None
@ -178,7 +189,7 @@ class RequestHandler:
def is_client_ratelimited(self) -> bool:
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
return False
else:
return True
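
Taken together, the handler now resolves a (priority, simultaneous_ip) pair per token, treats a NULL simultaneous_ip as effectively unlimited, and ratelimits against that per-token value instead of the global opts.simultaneous_requests_per_ip. A self-contained sketch of that decision logic, with plain dicts standing in for the Redis-backed counters and hypothetical constants in place of opts:

    DEFAULT_PRIORITY = 9999
    GLOBAL_SIMULTANEOUS_IP = 3  # stands in for opts.simultaneous_requests_per_ip

    def token_ratelimit(row) -> tuple:
        # row is (priority, simultaneous_ip) from token_auth, or None when no token matched.
        if row is None:
            return DEFAULT_PRIORITY, GLOBAL_SIMULTANEOUS_IP
        priority, simultaneous_ip = row
        if simultaneous_ip is None:
            simultaneous_ip = 999999999  # NULL means no per-IP cap for this token
        return priority, simultaneous_ip

    def is_client_ratelimited(client_ip: str, priority: int, simultaneous_ip: int,
                              queued_ip_count: dict, processing_ips: dict) -> bool:
        active = queued_ip_count.get(client_ip, 0) + processing_ips.get(client_ip, 0)
        # Priority-0 tokens bypass the limit entirely, mirroring the diff.
        return not (active < simultaneous_ip or priority == 0)

    # Example: a token capped at 2 simultaneous requests per IP is ratelimited at 2 active requests.
    assert is_client_ratelimited('1.2.3.4', 9999, 2, {'1.2.3.4': 2}, {}) is True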


@ -8,7 +8,7 @@ from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...database.database import increment_token_uses, log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock
@ -40,6 +40,7 @@ def stream(ws):
raise NotImplementedError
handler = OobaRequestHandler(request, request_json_body)
token = request_json_body.get('X-API-KEY')
generated_text = ''
input_prompt = None
response_status_code = 0
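
The streaming endpoint pulls X-API-KEY out of the JSON body, since websocket clients cannot reliably set custom headers, and imports the new increment_token_uses; where that counter is actually invoked falls outside the lines shown here. A one-function sketch of the body-first, header-fallback lookup that mirrors get_auth_token() above (the plain dicts are stand-ins for request_json_body and request.headers):

    from typing import Optional

    def resolve_token(request_json_body: dict, headers: dict) -> Optional[str]:
        # Prefer the key embedded in the websocket payload, then fall back to the header.
        return request_json_body.get('X-API-KEY') or headers.get('X-Api-Key')

    assert resolve_token({'X-API-KEY': 'abc'}, {}) == 'abc'
    assert resolve_token({}, {'X-Api-Key': 'def'}) == 'def'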


@ -15,6 +15,9 @@ from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
# TODO: allow setting more custom ratelimits per-token
# TODO: add more excluding to SYSTEM__ tokens
try:
import vllm
except ModuleNotFoundError as e: