import sqlite3
import time
from typing import Tuple, Union

from flask import jsonify
from requests.exceptions import InvalidJSONError

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread
# Priority assigned to requests without a recognized API token.
# Judging by priority_queue usage, a lower number means higher priority — TODO confirm.
DEFAULT_PRIORITY = 9999
class OobaRequestHandler:
    """Per-request handler for the oobabooga-compatible generation endpoint.

    Parses the incoming Flask request, resolves the client's IP, API token and
    queue priority, selects the configured LLM backend, and dispatches the job
    through the shared priority queue with per-IP rate limiting.
    """

    def __init__(self, incoming_request):
        # incoming_request is a Flask request object (headers, data, remote_addr).
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.parameters = self.get_parameters()
        self.priority = self.get_priority()
        self.backend = self.get_backend()

    def validate_request(self) -> Tuple[bool, Union[str, None]]:
        """Return (ok, error_message); only checks max_new_tokens against the server cap."""
        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

    def get_client_ip(self):
        """Best-effort client IP: Cloudflare header, then X-Forwarded-For, then socket address."""
        if self.request.headers.get('cf-connecting-ip'):
            return self.request.headers.get('cf-connecting-ip')
        elif self.request.headers.get('x-forwarded-for'):
            # X-Forwarded-For may carry a proxy chain; the first entry is the originating client.
            return self.request.headers.get('x-forwarded-for').split(',')[0]
        else:
            return self.request.remote_addr

    def get_parameters(self):
        """Parse the JSON body and return the generation parameters (body minus 'prompt').

        NOTE(review): on invalid JSON this returns a Flask (response, 400) tuple
        instead of a dict, which handle_request() would later trip over when it
        calls .get() on self.parameters — preserved as-is; confirm the intended
        error flow with the route that constructs this handler.
        """
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
        parameters = self.request_json_body.copy()
        # pop() instead of del: a body without 'prompt' must not raise KeyError here.
        parameters.pop('prompt', None)
        return parameters

    def get_priority(self):
        """Look up the token's priority in the token_auth table; fall back to DEFAULT_PRIORITY."""
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            try:
                cursor = conn.cursor()
                # Parameterized query: the token comes from an untrusted request header.
                cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
                result = cursor.fetchone()
            finally:
                # Close even if the query raises, so the connection is never leaked.
                conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY

    def get_backend(self):
        """Instantiate the LLM backend selected by opts.mode.

        Raises:
            Exception: if opts.mode names no known backend.
        """
        if opts.mode == 'oobabooga':
            return OobaboogaLLMBackend()
        elif opts.mode == 'hf-textgen':
            return HfTextgenLLMBackend()
        else:
            # Same exception type as before, now with a diagnostic message.
            raise Exception(f'Unknown backend mode: {opts.mode}')

    def handle_request(self):
        """Validate, queue, and dispatch the request; returns a Flask (response, status) tuple."""
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

        # Fix bug on text-generation-inference
        # https://github.com/huggingface/text-generation-inference/issues/929
        if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
            self.parameters['typical_p'] = 0.998

        request_valid, invalid_request_err_msg = self.validate_request()
        params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
        if not request_valid or not params_valid:
            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # HTTP 200 with an error payload is deliberate: SillyTavern renders the text body.
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Rate limit on queued plus in-flight requests for this IP; priority 0 bypasses it.
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.ip_in_queue_max or self.priority == 0:
            event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
            event = None

        if not event:
            return self.handle_ratelimited()
        # Block until a worker finishes this job and fills event.data.
        event.wait()
        success, response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        """Log the rejected prompt and return a SillyTavern-visible rate-limit message (HTTP 200)."""
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200