import sqlite3
import time
from typing import Tuple, Union

from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread

DEFAULT_PRIORITY = 9999

def delete_dict_key(d: dict, k: Union[str, list]) -> dict:
    if isinstance(k, str):
        if k in d:
            del d[k]
    elif isinstance(k, list):
        for item in k:
            if item in d:
                del d[item]
    else:
        raise ValueError(f'k must be a str or list, not {type(k).__name__}')
    return d
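
# Illustrative example (not part of the original file): delete_dict_key() is
# used below to strip OpenAI-style fields before a request is forwarded to vLLM:
#   delete_dict_key({'prompt': 'hi', 'model': 'x', 'stream': False}, ['model', 'stream'])
#   -> {'prompt': 'hi'}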

class OobaRequestHandler:
    def __init__(self, incoming_request):
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.parameters = self.get_parameters()
        self.priority = self.get_priority()
        self.backend = self.get_backend()

    def validate_request(self) -> Tuple[bool, Union[str, None]]:
        # TODO: move this to LLMBackend
        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
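
    # Example (illustrative): with opts.max_new_tokens = 500, a body containing
    # {'max_new_tokens': 1024} fails validation with the message
    # '`max_new_tokens` must be less than or equal to 500'.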

    def get_client_ip(self):
        # Prefer Cloudflare's header, then the first hop in X-Forwarded-For,
        # and fall back to the socket address.
        if self.request.headers.get('cf-connecting-ip'):
            return self.request.headers.get('cf-connecting-ip')
        elif self.request.headers.get('x-forwarded-for'):
            return self.request.headers.get('x-forwarded-for').split(',')[0]
        else:
            return self.request.remote_addr
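
    # Example (illustrative): with 'X-Forwarded-For: 203.0.113.7, 198.51.100.2',
    # get_client_ip() returns '203.0.113.7', the left-most (client-most) address
    # in the proxy chain rather than an intermediate hop.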

    def get_parameters(self):
        # TODO: make this a LLMBackend method
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
        parameters = self.request_json_body.copy()
        if opts.mode in ['oobabooga', 'hf-textgen']:
            del parameters['prompt']
        elif opts.mode == 'vllm':
            parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
        else:
            raise Exception(f'Unknown backend mode: {opts.mode}')
        return parameters
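
    # Example (illustrative): in 'oobabooga' mode a body of
    #   {'prompt': 'Hello', 'max_new_tokens': 200}
    # yields parameters == {'max_new_tokens': 200}; the prompt itself is read
    # from request_json_body when the request is handled.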

    def get_priority(self):
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
            conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY
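
    # Assumed minimal schema for the lookup above (illustrative; the real table
    # may carry more columns):
    #   CREATE TABLE token_auth (token TEXT PRIMARY KEY, priority INTEGER);
    # Priority 0 bypasses the per-IP rate limit in handle_request(), and
    # DEFAULT_PRIORITY = 9999 is assumed to place anonymous clients at the back
    # of the queue.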

    def get_backend(self):
        if opts.mode == 'oobabooga':
            return OobaboogaLLMBackend()
        elif opts.mode == 'hf-textgen':
            return HfTextgenLLMBackend()
        elif opts.mode == 'vllm':
            return VLLMBackend()
        else:
            raise Exception(f'Unknown backend mode: {opts.mode}')

    def handle_request(self):
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

        # Fix bug on text-generation-inference
        # https://github.com/huggingface/text-generation-inference/issues/929
        if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
            self.request_json_body['typical_p'] = 0.998
        request_valid, invalid_request_err_msg = self.validate_request()
        params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
        if not request_valid or not params_valid:
            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            # HTTP 200 so SillyTavern displays the error text instead of discarding it.
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
            event = None
        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, response, error_msg = event.data
        end_time = time.time()
        elapsed_time = end_time - self.start_time
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
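
    # The queue worker (in llm_server.routes.queue) is assumed to resolve each
    # event with a (success, response, error_msg) 3-tuple, which
    # backend.handle_response() then formats into a backend-specific reply.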

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200
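
# Illustrative wiring (assumed route shape; the actual Flask blueprint lives
# elsewhere in llm_server.routes):
#
#   from flask import request
#
#   @bp.route('/generate', methods=['POST'])
#   def generate():
#       return OobaRequestHandler(request).handle_request()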