import re
import time
from uuid import uuid4
import tiktoken
from flask import jsonify
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler

tokenizer = tiktoken.get_encoding("cl100k_base")
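

# OpenAIRequestHandler accepts OpenAI-style chat completion requests and maps
# them onto the backend's ### INSTRUCTION/### USER/### ASSISTANT prompt format.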
class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None
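
    # Full request flow: validate the JSON body, build the prompt from the
    # OpenAI messages list, validate the sampling parameters, queue the
    # request, then wrap the backend's reply in an OpenAI-style response.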
    def handle_request(self):
        if self.used:
            raise Exception

        request_valid_json, self.request_json_body = validate_json(self.request)
        self.prompt = self.transform_messages_to_prompt()
        if not request_valid_json or not self.prompt:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        llm_request = {**self.parameters, 'prompt': self.prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None
        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, backend_response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time
        self.used = True
        response, response_status_code = self.backend.handle_response(success=success, request=self.request, response=backend_response, error_msg=error_msg, client_ip=self.client_ip, token=self.token, prompt=self.prompt, elapsed_time=elapsed_time, parameters=self.parameters, headers=dict(self.request.headers))
        return build_openai_response(self.prompt, response.json['results'][0]['text']), 200
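
    # Responds with HTTP 200 and the ratelimit notice as the assistant message,
    # so frontends such as SillyTavern display it like a normal completion.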
    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
        return build_openai_response(self.prompt, backend_response), 200
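
    # Flattens the OpenAI `messages` list into a single Alpaca-style prompt.
    # For example, [{'role': 'user', 'content': 'hi'}] becomes roughly:
    #   ### INSTRUCTION: <system prompt>### USER: hi
    #
    #   ### RESPONSE:
    # Returns False if any message is missing 'role'/'content' or uses an unknown role.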
    def transform_messages_to_prompt(self):
        try:
            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
            for msg in self.request.json['messages']:
                if not msg.get('content') or not msg.get('role'):
                    return False
                if msg['role'] == 'system':
                    prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
                elif msg['role'] == 'user':
                    prompt += f'### USER: {msg["content"]}\n\n'
                elif msg['role'] == 'assistant':
                    prompt += f'### ASSISTANT: {msg["content"]}\n\n'
                else:
                    return False
        except:
            return False
        prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
        prompt += '\n\n### RESPONSE: '
        return prompt
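

# Wraps the backend's generated text in an OpenAI chat.completion-shaped JSON
# body, counting prompt/completion tokens with tiktoken's cl100k_base encoding.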
def build_openai_response(prompt, response):
    # Separate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    y = response.split('\n### ')
    if len(y) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })
# def transform_prompt_to_text(prompt: list):
# text = ''
# for item in prompt:
# text += item['content'] + '\n'
# return text.strip('\n')