2023-08-30 18:53:26 -06:00
|
|
|
import time
|
2023-09-14 14:05:50 -06:00
|
|
|
from typing import Tuple, Union
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-13 11:22:33 -06:00
|
|
|
import flask
|
2023-09-17 18:55:36 -06:00
|
|
|
from flask import Response, request
|
2023-09-13 11:22:33 -06:00
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
from llm_server import opts
|
2023-09-29 00:09:44 -06:00
|
|
|
from llm_server.cluster.backend import get_a_cluster_backend
|
|
|
|
from llm_server.cluster.cluster_config import cluster_config
|
|
|
|
from llm_server.custom_redis import redis
|
2023-09-26 23:59:22 -06:00
|
|
|
from llm_server.database.conn import database
|
2023-09-20 20:30:31 -06:00
|
|
|
from llm_server.database.database import log_prompt
|
2023-09-23 23:24:08 -06:00
|
|
|
from llm_server.helpers import auto_set_base_client_api
|
2023-09-12 16:40:09 -06:00
|
|
|
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
2023-09-11 20:47:19 -06:00
|
|
|
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
2023-09-26 22:09:11 -06:00
|
|
|
from llm_server.routes.auth import parse_token
|
2023-09-17 18:55:36 -06:00
|
|
|
from llm_server.routes.helpers.http import require_api_key, validate_json
|
2023-09-14 14:05:50 -06:00
|
|
|
from llm_server.routes.queue import priority_queue
|
2023-08-30 18:53:26 -06:00
|
|
|
|
|
|
|
# Queue priority assigned to requests whose token has no explicit priority row.
# Presumably a lower number means higher priority (priority 0 bypasses the
# ratelimit check in is_client_ratelimited) — TODO confirm queue ordering.
DEFAULT_PRIORITY = 9999
|
|
|
|
|
|
|
|
|
2023-09-12 16:40:09 -06:00
|
|
|
class RequestHandler:
    """Base handler for one incoming LLM request.

    Loads and validates the JSON body, resolves the client IP, auth token and
    its ratelimit, and picks a cluster backend. Subclasses implement
    handle_request(), handle_ratelimited() and handle_error().
    """

    def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
        """
        :param incoming_request: the flask request being handled.
        :param incoming_json: optional pre-extracted JSON (dict or raw string);
            when omitted, the JSON is parsed from the request itself.
        :raises Exception: if the body is not valid JSON (routes should have
            rejected it already).
        """
        self.request = incoming_request
        # Read the header from the request we were handed, not flask's global
        # `request` proxy, so a handler built for a different request object
        # still inspects its own headers.
        self.enable_backend_blind_rrd = self.request.headers.get('LLM-Blind-RRD', False) == 'true'

        # Routes need to validate it, here we just load it
        if incoming_json:
            self.request_valid_json, self.request_json_body = validate_json(incoming_json)
        else:
            self.request_valid_json, self.request_json_body = validate_json(self.request)
        if not self.request_valid_json:
            # Plain string (the original was an f-string with no placeholders).
            raise Exception('Not valid JSON. Routes are supposed to reject invalid JSON.')

        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.get_auth_token()
        self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
        self.cluster_backend = get_a_cluster_backend()
        self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
        self.backend = get_backend_handler(self.cluster_backend)
        self.parameters = None
        self.used = False
        # Record this client as a recent prompter, scored by last-seen time.
        redis.zadd('recent_prompters', {self.client_ip: time.time()})
2023-09-25 00:55:20 -06:00
|
|
|
def get_auth_token(self):
|
2023-09-26 22:09:11 -06:00
|
|
|
if self.request_json_body.get('X-API-KEY'):
|
2023-09-26 22:49:53 -06:00
|
|
|
return self.request_json_body['X-API-KEY']
|
2023-09-26 22:09:11 -06:00
|
|
|
elif self.request.headers.get('X-Api-Key'):
|
|
|
|
return self.request.headers['X-Api-Key']
|
2023-09-26 22:49:53 -06:00
|
|
|
elif self.request.headers.get('Authorization'):
|
2023-09-26 22:09:11 -06:00
|
|
|
return parse_token(self.request.headers['Authorization'])
|
2023-09-25 00:55:20 -06:00
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
def get_client_ip(self):
|
2023-09-25 09:32:23 -06:00
|
|
|
if self.request.headers.get('X-Connecting-IP'):
|
|
|
|
return self.request.headers.get('X-Connecting-IP')
|
|
|
|
elif self.request.headers.get('Cf-Connecting-Ip'):
|
|
|
|
return self.request.headers.get('Cf-Connecting-Ip')
|
|
|
|
elif self.request.headers.get('X-Forwarded-For'):
|
|
|
|
return self.request.headers.get('X-Forwarded-For').split(',')[0]
|
2023-08-30 18:53:26 -06:00
|
|
|
else:
|
|
|
|
return self.request.remote_addr
|
|
|
|
|
2023-09-25 00:55:20 -06:00
|
|
|
    def get_token_ratelimit(self):
        """Look up this token's queue priority and per-IP simultaneous-request limit.

        Falls back to DEFAULT_PRIORITY and the global
        opts.simultaneous_requests_per_ip when no token was supplied or the
        token has no row in token_auth.

        :return: (priority, simultaneous_ip) tuple.
        """
        priority = DEFAULT_PRIORITY
        simultaneous_ip = opts.simultaneous_requests_per_ip
        if self.token:
            cursor = database.cursor()
            try:
                cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
                result = cursor.fetchone()
                if result:
                    priority, simultaneous_ip = result
                    if simultaneous_ip is None:
                        # No ratelimit for this token if null
                        simultaneous_ip = 999999999
            finally:
                # Always release the cursor, even if the query raises.
                cursor.close()
        return priority, simultaneous_ip
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-14 17:38:20 -06:00
|
|
|
def get_parameters(self):
|
2023-09-12 16:40:09 -06:00
|
|
|
if self.request_json_body.get('max_tokens'):
|
|
|
|
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
|
2023-09-14 17:38:20 -06:00
|
|
|
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
|
|
|
|
return parameters, parameters_invalid_msg
|
2023-09-12 01:04:11 -06:00
|
|
|
|
2023-09-24 13:02:30 -06:00
|
|
|
    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """
        This needs to be called at the start of the subclass handle_request() method.

        Collects every validation failure (parameter limits, prompt checks,
        backend-specific request checks) into one combined error message.

        :param prompt: optional prompt text to validate against the backend.
        :param do_log: when True, log the rejected prompt via log_prompt().
        :return: (True, (None, 0)) when valid, otherwise
            (False, (flask Response, status code)) built by handle_error().
        """
        invalid_request_err_msgs = []

        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
        if self.parameters and not parameters_invalid_msg:
            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
            # Let the RequestHandler do the generic checks.
            if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')

            if prompt:
                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
                if not prompt_valid:
                    invalid_request_err_msgs.append(invalid_prompt_err_msg)

            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
            if not request_valid:
                invalid_request_err_msgs.append(invalid_request_err_msg)
        else:
            # Parameters failed to parse; that message is the only error.
            invalid_request_err_msgs.append(parameters_invalid_msg)

        if len(invalid_request_err_msgs):
            if len(invalid_request_err_msgs) > 1:
                # Format multiple error messages each on a new line.
                e = [f'\n{x}.' for x in invalid_request_err_msgs]
                combined_error_message = '\n'.join(e)
            else:
                # Otherwise, just grab the first and only one.
                combined_error_message = invalid_request_err_msgs[0] + '.'
            backend_response = self.handle_error(combined_error_message, 'Validation Error')

            if do_log:
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
            return False, backend_response
        return True, (None, 0)
|
|
|
def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
|
2023-09-14 18:31:13 -06:00
|
|
|
prompt = llm_request['prompt']
|
2023-09-14 14:05:50 -06:00
|
|
|
if not self.is_client_ratelimited():
|
2023-09-24 13:02:30 -06:00
|
|
|
# Validate again before submission since the backend handler may have changed something.
|
|
|
|
# Also, this is the first time we validate the prompt.
|
|
|
|
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
|
|
|
|
if not request_valid:
|
|
|
|
return (False, None, None, 0), invalid_response
|
2023-09-29 00:09:44 -06:00
|
|
|
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
|
2023-09-14 14:05:50 -06:00
|
|
|
else:
|
|
|
|
event = None
|
|
|
|
|
|
|
|
if not event:
|
|
|
|
return (False, None, None, 0), self.handle_ratelimited()
|
|
|
|
|
2023-09-14 17:38:20 -06:00
|
|
|
success, response, error_msg = event.wait()
|
2023-09-14 14:05:50 -06:00
|
|
|
end_time = time.time()
|
|
|
|
elapsed_time = end_time - self.start_time
|
|
|
|
|
|
|
|
if response:
|
|
|
|
try:
|
|
|
|
# Be extra careful when getting attributes from the response object
|
|
|
|
response_status_code = response.status_code
|
|
|
|
except:
|
|
|
|
response_status_code = 0
|
|
|
|
else:
|
|
|
|
response_status_code = None
|
|
|
|
|
|
|
|
# ===============================================
|
|
|
|
|
|
|
|
# We encountered an error
|
|
|
|
if not success or not response or error_msg:
|
|
|
|
if not error_msg or error_msg == '':
|
|
|
|
error_msg = 'Unknown error.'
|
|
|
|
else:
|
|
|
|
error_msg = error_msg.strip('.') + '.'
|
2023-09-27 14:48:47 -06:00
|
|
|
backend_response = self.handle_error(error_msg)
|
2023-09-29 00:09:44 -06:00
|
|
|
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
2023-09-26 23:59:22 -06:00
|
|
|
return (False, None, None, 0), backend_response
|
2023-09-14 14:05:50 -06:00
|
|
|
|
|
|
|
# ===============================================
|
|
|
|
|
|
|
|
response_valid_json, response_json_body = validate_json(response)
|
2023-09-14 14:26:25 -06:00
|
|
|
return_json_err = False
|
2023-09-14 14:05:50 -06:00
|
|
|
|
|
|
|
# The backend didn't send valid JSON
|
|
|
|
if not response_valid_json:
|
2023-09-14 14:26:25 -06:00
|
|
|
return_json_err = True
|
|
|
|
|
|
|
|
# Make sure the backend didn't crap out.
|
|
|
|
results = response_json_body.get('results', [])
|
|
|
|
if len(results) and not results[0].get('text'):
|
|
|
|
return_json_err = True
|
|
|
|
|
|
|
|
if return_json_err:
|
2023-09-14 14:05:50 -06:00
|
|
|
error_msg = 'The backend did not return valid JSON.'
|
2023-09-27 14:48:47 -06:00
|
|
|
backend_response = self.handle_error(error_msg)
|
2023-09-29 00:09:44 -06:00
|
|
|
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
2023-09-26 23:59:22 -06:00
|
|
|
return (False, None, None, 0), backend_response
|
2023-09-14 14:05:50 -06:00
|
|
|
|
|
|
|
# ===============================================
|
|
|
|
|
|
|
|
self.used = True
|
|
|
|
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
|
|
|
|
|
|
|
def is_client_ratelimited(self) -> bool:
|
2023-09-28 03:44:30 -06:00
|
|
|
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
|
|
|
|
x = redis.hget('processing_ips', self.client_ip)
|
|
|
|
if x:
|
|
|
|
processing_ip = int(x)
|
|
|
|
else:
|
|
|
|
processing_ip = 0
|
|
|
|
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
|
2023-09-12 16:40:09 -06:00
|
|
|
return False
|
2023-08-30 18:53:26 -06:00
|
|
|
else:
|
2023-09-28 03:44:30 -06:00
|
|
|
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
|
2023-09-12 16:40:09 -06:00
|
|
|
return True
|
|
|
|
|
2023-09-14 14:05:50 -06:00
|
|
|
def handle_request(self) -> Tuple[flask.Response, int]:
|
|
|
|
# Must include this in your child.
|
|
|
|
# if self.used:
|
|
|
|
# raise Exception('Can only use a RequestHandler object once.')
|
2023-09-12 16:40:09 -06:00
|
|
|
raise NotImplementedError
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-28 01:34:15 -06:00
|
|
|
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
|
2023-09-12 16:40:09 -06:00
|
|
|
raise NotImplementedError
|
2023-09-12 01:04:11 -06:00
|
|
|
|
2023-09-27 14:48:47 -06:00
|
|
|
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
2023-09-14 17:38:20 -06:00
|
|
|
raise NotImplementedError
|
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-29 00:09:44 -06:00
|
|
|
def get_backend_handler(mode):
    """Return the backend implementation object for the given backend mode.

    :param mode: backend mode string; 'oobabooga' or 'vllm'.
    :return: an OobaboogaBackend or VLLMBackend instance.
    :raises Exception: if the mode is not recognized.
    """
    if mode == 'oobabooga':
        return OobaboogaBackend()
    elif mode == 'vllm':
        return VLLMBackend()
    else:
        # Was a bare `raise Exception`; include the offending mode so the
        # failure is debuggable.
        raise Exception(f'Unknown backend mode: {mode}')
|
2023-09-12 01:04:11 -06:00
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-12 16:40:09 -06:00
|
|
|
def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove one key or a list of keys from a dict in place and return it.

    Missing keys are ignored.

    :param d: the dict to mutate.
    :param k: a single key (str) or a list of keys.
    :return: the same dict, for convenience.
    :raises ValueError: if k is neither a str nor a list.
    """
    if isinstance(k, str):
        # pop with a default avoids the `k in d.keys()` double-lookup and
        # silently skips missing keys, matching the original behavior.
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        # Was a bare `raise ValueError`; say what was wrong.
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    return d
|
2023-09-17 18:55:36 -06:00
|
|
|
|
|
|
|
|
|
|
|
def before_request():
    """Flask before-request hook.

    Sets the base client API from the incoming request, then enforces the API
    key on every endpoint except 'v1.get_stats'. Returning a response here
    short-circuits the request; returning None lets it proceed.
    """
    auto_set_base_client_api(request)
    if request.endpoint != 'v1.get_stats':
        failure = require_api_key()
        if failure is not None:
            return failure
|