# 2023-08-30 18:53:26 -06:00
import sqlite3
import time
from typing import Tuple, Union

from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
# 2023-09-11 20:47:19 -06:00
from llm_server.llm.vllm.vllm_backend import VLLMBackend
# 2023-08-30 18:53:26 -06:00
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread
DEFAULT_PRIORITY = 9999
# 2023-09-11 20:47:19 -06:00
def delete_dict_key(d: dict, k: Union[str, list]) -> dict:
    """Remove key(s) `k` from `d` in place and return `d`.

    :param d: dictionary to mutate.
    :param k: a single key, or a list of keys; keys absent from `d` are ignored.
    :return: the same dict `d`, mutated.
    :raises ValueError: if `k` is neither a str nor a list.
    """
    if isinstance(k, str):
        # pop with a default avoids the membership-test-then-delete double lookup.
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    return d
# 2023-08-30 18:53:26 -06:00
class OobaRequestHandler:
    """Handle one incoming generation request end to end.

    On construction, captures the client IP, API token, queue priority, and the
    configured LLM backend from the incoming Flask request. `handle_request()`
    then parses the JSON body, validates parameters, applies per-IP rate
    limiting, enqueues the work on the priority queue, and returns the
    backend's response.
    """

    def __init__(self, incoming_request):
        # incoming_request: Flask request object (provides .headers, .data, .remote_addr).
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = self.get_backend()
        # Populated by get_parameters() once the JSON body has been parsed.
        self.parameters = self.parameters_invalid_msg = None

    def validate_request(self) -> Tuple[bool, Union[str, None]]:
        """Return (is_valid, error_message); error_message is None when valid."""
        # TODO: move this to LLMBackend
        # Guard: self.parameters is None when the backend failed to parse the
        # body; without this check .get() would raise AttributeError. The
        # caller reports self.parameters_invalid_msg in that case.
        if self.parameters:
            if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
                return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

    def get_client_ip(self):
        """Best-effort real client IP: Cloudflare header, then X-Forwarded-For, then the socket address."""
        if self.request.headers.get('cf-connecting-ip'):
            return self.request.headers.get('cf-connecting-ip')
        elif self.request.headers.get('x-forwarded-for'):
            # X-Forwarded-For may be a comma-separated proxy chain; the first entry is the client.
            return self.request.headers.get('x-forwarded-for').split(',')[0]
        else:
            return self.request.remote_addr

    def get_priority(self):
        """Look up this token's priority in the token_auth table; fall back to DEFAULT_PRIORITY."""
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
                result = cursor.fetchone()
            finally:
                # Always release the connection, even if the query raises.
                conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY

    def get_backend(self):
        """Instantiate the LLM backend selected by opts.mode."""
        if opts.mode == 'oobabooga':
            return OobaboogaLLMBackend()
        elif opts.mode == 'hf-textgen':
            return HfTextgenLLMBackend()
        elif opts.mode == 'vllm':
            return VLLMBackend()
        else:
            raise Exception(f'Unknown backend mode: {opts.mode}')

    def get_parameters(self):
        """Ask the backend to extract/validate sampling parameters from the parsed JSON body.

        Sets self.parameters (dict, or None on failure) and
        self.parameters_invalid_msg (error string, or None on success).
        """
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def handle_request(self):
        """Full request lifecycle: parse JSON, validate, rate-limit, queue, respond."""
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        self.get_parameters()

        # Record this client as a recent prompter for the stats thread.
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

        request_valid, invalid_request_err_msg = self.validate_request()
        params_valid = bool(self.parameters)

        if not request_valid or not params_valid:
            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            # NOTE(review): HTTP 200 (not 400) appears deliberate — presumably so the
            # client renders the error text in-chat. Confirm before changing.
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        # Queued plus in-flight requests for this IP, used for the per-IP limit.
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            # Priority 0 is exempt from the per-IP simultaneous-request limit.
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
            event = None
        if not event:
            return self.handle_ratelimited()

        # Block until a worker completes this request and publishes its result.
        event.wait()
        success, response, error_msg = event.data
        elapsed_time = time.time() - self.start_time
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        """Log the rejected prompt and return a rate-limit error message to the client."""
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        # NOTE(review): returns 200 with the error in 'results' (429 goes only to the
        # log) — matches the validation-error path above; confirm before changing.
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200