133 lines
5.3 KiB
Python
133 lines
5.3 KiB
Python
|
import re
|
||
|
import time
|
||
|
from uuid import uuid4
|
||
|
|
||
|
import tiktoken
|
||
|
from flask import jsonify
|
||
|
|
||
|
from llm_server import opts
|
||
|
from llm_server.database import log_prompt
|
||
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
||
|
from llm_server.routes.helpers.http import validate_json
|
||
|
from llm_server.routes.queue import priority_queue
|
||
|
from llm_server.routes.request_handler import RequestHandler
|
||
|
|
||
|
# Module-level tiktoken encoder for the "cl100k_base" encoding. build_openai_response
# uses it to fill in the token counts of the OpenAI-style "usage" block.
# NOTE(review): these counts come from a fixed OpenAI encoding, not from the
# running model's own tokenizer, so they are presumably approximate — confirm.
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||
|
|
||
|
|
||
|
class OpenAIRequestHandler(RequestHandler):
    """Serve an OpenAI-style chat-completion request with the local backend.

    The request's ``messages`` list is flattened into a single
    ``### ROLE: ...`` text prompt and queued on ``priority_queue``. The
    generated text is returned wrapped in an OpenAI ``chat.completion`` body.
    """

    # Appended to every request's stop list so that generation halts before the
    # model starts writing another section of the flattened transcript.
    PROMPT_STOP_SEQUENCES = ('\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE')

    # Maps each OpenAI message role to the section header used in the flattened prompt.
    # Any role not listed here makes the request invalid.
    ROLE_HEADERS = {
        'system': 'INSTRUCTION',
        'user': 'USER',
        'assistant': 'ASSISTANT',
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flattened prompt; handle_request() sets it once the body has been parsed.
        self.prompt = None

    def handle_request(self):
        """Validate the request, run it through the backend queue and build the reply.

        Returns:
            A ``(flask response, HTTP status)`` tuple. Unparseable JSON or
            malformed ``messages`` yield a 400. Parameter-validation errors and
            rate limiting are reported as an OpenAI-shaped completion with
            status 200.

        Raises:
            RuntimeError: if this handler instance has already served a request.
        """
        if self.used:
            raise RuntimeError('This request handler has already been used.')

        request_valid_json, self.request_json_body = validate_json(self.request)
        # Only flatten messages from a body that actually parsed as JSON.
        self.prompt = self.transform_messages_to_prompt() if request_valid_json else None
        if not self.prompt:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            error_messages = [msg for valid, msg in (request_valid, params_valid) if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            # Chat requests carry `messages`, not `prompt`, so log the flattened
            # prompt rather than a field that is always empty.
            log_prompt(self.client_ip, self.token, self.prompt, err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # Reply in the same OpenAI shape as handle_ratelimited() so that
            # OpenAI-compatible clients can display the error.
            return build_openai_response(self.prompt, err), 200

        # Build a fresh stop list instead of extending the existing one in place.
        # If that list were shared (e.g. a default), extending it would grow it on
        # every request. This also tolerates a missing/None stop or a single string.
        stop = self.parameters.get('stop') or []
        if isinstance(stop, str):
            stop = [stop]
        self.parameters['stop'] = [*stop, *self.PROMPT_STOP_SEQUENCES]

        # Reconstruct the request JSON with the validated parameters and prompt.
        llm_request = {**self.parameters, 'prompt': self.prompt}

        if self.is_client_ratelimited():
            return self.handle_ratelimited()
        event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        if not event:
            # A falsy return from put() is treated as a rate-limit rejection.
            return self.handle_ratelimited()

        event.wait()
        success, backend_response, error_msg = event.data
        elapsed_time = time.time() - self.start_time

        self.used = True
        # NOTE(review): the status code returned by handle_response() is discarded;
        # this endpoint always replies 200.
        response, _backend_status = self.backend.handle_response(success, backend_response, error_msg, self.client_ip, self.token, self.prompt, elapsed_time, self.parameters, dict(self.request.headers))
        return build_openai_response(self.prompt, response.json['results'][0]['text']), 200

    def handle_ratelimited(self):
        """Log a rate-limited request and return the error as an OpenAI-shaped completion.

        The log entry records status 429, but the HTTP reply itself is 200 with
        the error text as the assistant message.
        """
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        # Log the flattened prompt; chat request bodies have no `prompt` field.
        log_prompt(self.client_ip, self.token, self.prompt, backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return build_openai_response(self.prompt, backend_response), 200

    def transform_messages_to_prompt(self):
        """Flatten the request body's OpenAI ``messages`` list into one text prompt.

        The prompt contains, separated by blank lines:

        - the configured system prompt as an ``### INSTRUCTION:`` section;
        - one ``### ROLE: content`` section per message;
        - an open ``### RESPONSE: `` header for the model to complete.

        Returns:
            The prompt string, or ``False`` in any of these cases:

            - the body has no iterable ``messages``;
            - an entry is not a mapping;
            - an entry lacks a non-empty ``role``/``content``;
            - an entry uses an unsupported role.
        """
        sections = [f'### INSTRUCTION: {opts.openai_system_prompt}']
        try:
            for msg in self.request_json_body['messages']:
                role = msg.get('role')
                content = msg.get('content')
                if not content or not role or role not in self.ROLE_HEADERS:
                    return False
                sections.append(f'### {self.ROLE_HEADERS[role]}: {content}')
        except (KeyError, TypeError, AttributeError):
            # The body is not a dict, `messages` is missing or not iterable,
            # an entry is not a dict, or a role is unhashable.
            return False

        prompt = '\n\n'.join(sections).strip(' \n')
        return prompt + '\n\n### RESPONSE: '
|
||
|
|
||
|
|
||
|
def build_openai_response(prompt, response):
    """Wrap backend output in an OpenAI ``chat.completion`` JSON response.

    Args:
        prompt: The flattened ``### ROLE:`` prompt sent to the backend. When
            it contains a ``### USER:`` header, only the text after the last
            such header is counted for ``prompt_tokens``.
        response: Text generated by the backend. It is cut at the first
            ``'\\n### '`` so that any extra instruction/turn sections the model
            started writing are not returned.

    Returns:
        A Flask JSON response (from ``jsonify``) containing a single
        ``assistant`` choice, plus ``usage`` counts computed with the
        module-level ``tokenizer``.
    """
    # Separate the latest user turn from the rest of the context.
    user_turns = prompt.split('### USER:')
    if len(user_turns) > 1:
        prompt = re.sub(r'\n$', '', user_turns[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response.
    # The cut is applied to every response. Previously the guard tested the
    # prompt split, so when the prompt had no `### USER:` section the response
    # was never truncated.
    response = re.sub(r'\n$', '', response.split('\n### ')[0].strip(' '))

    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })
|
||
|
|
||
|
# def transform_prompt_to_text(prompt: list):
|
||
|
# text = ''
|
||
|
# for item in prompt:
|
||
|
# text += item['content'] + '\n'
|
||
|
# return text.strip('\n')
|