allow setting simultaneous IP limit per-token, fix token use tracker, fix tokens on streaming
This commit is contained in:
parent
d2651756df
commit
6459a1c91b
|
@ -31,6 +31,7 @@ def create_db():
|
|||
UNIQUE (token),
|
||||
type TEXT NOT NULL,
|
||||
priority INTEGER DEFAULT 9999,
|
||||
simultaneous_ip INTEGER DEFAULT NULL,
|
||||
uses INTEGER DEFAULT 0,
|
||||
max_uses INTEGER,
|
||||
expire INTEGER,
|
||||
|
@ -39,4 +40,3 @@ def create_db():
|
|||
''')
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
|
|
|
@ -30,6 +30,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
# TODO: test and verify this works as expected
|
||||
response = None
|
||||
|
||||
if token:
|
||||
increment_token_uses(token)
|
||||
|
||||
timestamp = int(time.time())
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
|
@ -61,21 +64,6 @@ def is_valid_api_key(api_key):
|
|||
cursor.close()
|
||||
|
||||
|
||||
def increment_uses(api_key):
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT token FROM token_auth WHERE token = %s", (api_key,))
|
||||
row = cursor.fetchone()
|
||||
if row is not None:
|
||||
cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = %s", (api_key,))
|
||||
return True
|
||||
conn.commit()
|
||||
return False
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def get_number_of_rows(table_name):
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
|
@ -114,7 +102,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
|
|||
cursor = conn.cursor()
|
||||
try:
|
||||
try:
|
||||
cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,))
|
||||
cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND token NOT LIKE 'SYSTEM__%%' ORDER BY id DESC", (model_name, backend_name, backend_url,))
|
||||
results = cursor.fetchall()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
@ -162,3 +150,12 @@ def get_distinct_ips_24h():
|
|||
return result[0] if result else 0
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def increment_token_uses(token):
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
|
|
@ -32,8 +32,8 @@ class RequestHandler:
|
|||
|
||||
self.start_time = time.time()
|
||||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.request.headers.get('X-Api-Key')
|
||||
self.priority = self.get_priority()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
|
||||
self.backend = get_backend()
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
|
@ -41,6 +41,13 @@ class RequestHandler:
|
|||
recent_prompters[self.client_ip] = time.time()
|
||||
redis.set_dict('recent_prompters', recent_prompters)
|
||||
|
||||
def get_auth_token(self):
|
||||
websocket_key = self.request_json_body.get('X-API-KEY')
|
||||
if websocket_key:
|
||||
return websocket_key
|
||||
else:
|
||||
return self.request.headers.get('X-Api-Key')
|
||||
|
||||
def get_client_ip(self):
|
||||
if self.request.headers.get('cf-connecting-ip'):
|
||||
return self.request.headers.get('cf-connecting-ip')
|
||||
|
@ -49,19 +56,23 @@ class RequestHandler:
|
|||
else:
|
||||
return self.request.remote_addr
|
||||
|
||||
def get_priority(self):
|
||||
def get_token_ratelimit(self):
|
||||
priority = DEFAULT_PRIORITY
|
||||
simultaneous_ip = opts.simultaneous_requests_per_ip
|
||||
if self.token:
|
||||
conn = db_pool.connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,))
|
||||
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
return result[0]
|
||||
priority, simultaneous_ip = result
|
||||
if simultaneous_ip is None:
|
||||
# No ratelimit for this token if null
|
||||
simultaneous_ip = 999999999
|
||||
finally:
|
||||
cursor.close()
|
||||
return DEFAULT_PRIORITY
|
||||
return priority, simultaneous_ip
|
||||
|
||||
def get_parameters(self):
|
||||
if self.request_json_body.get('max_tokens'):
|
||||
|
@ -119,7 +130,7 @@ class RequestHandler:
|
|||
if not request_valid:
|
||||
return (False, None, None, 0), invalid_response
|
||||
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
|
||||
else:
|
||||
event = None
|
||||
|
||||
|
@ -178,7 +189,7 @@ class RequestHandler:
|
|||
|
||||
def is_client_ratelimited(self) -> bool:
|
||||
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
|
||||
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
|
||||
if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
|
|
@ -8,7 +8,7 @@ from ..helpers.client import format_sillytavern_err
|
|||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...database.database import increment_token_uses, log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.vllm import tokenize
|
||||
from ...stream import sock
|
||||
|
@ -40,6 +40,7 @@ def stream(ws):
|
|||
raise NotImplementedError
|
||||
|
||||
handler = OobaRequestHandler(request, request_json_body)
|
||||
token = request_json_body.get('X-API-KEY')
|
||||
generated_text = ''
|
||||
input_prompt = None
|
||||
response_status_code = 0
|
||||
|
|
|
@ -15,6 +15,9 @@ from llm_server.llm import get_token_count
|
|||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
|
||||
# TODO: allow setting more custom ratelimits per-token
|
||||
# TODO: add more excluding to SYSTEM__ tokens
|
||||
|
||||
try:
|
||||
import vllm
|
||||
except ModuleNotFoundError as e:
|
||||
|
|
Reference in New Issue