local-llm-server/llm_server/routes/openai_request_handler.py

import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4

import flask
import requests
import tiktoken
from flask import jsonify

from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
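
# cl100k_base is the encoding used by OpenAI's chat models, so token counts
# computed with it line up with what OpenAI-compatible clients expect.
# Illustrative example: len(tokenizer.encode("Hello world")) -> 2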
tokenizer = tiktoken.get_encoding("cl100k_base")


class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        if self.used:
            raise Exception('This RequestHandler has already been used.')

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        self.prompt = self.transform_messages_to_prompt()

        if opts.openai_api_key:
            try:
                # Gather the last message from the user and all preceding system messages.
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                msgs_to_check = []
                for msg in msg_l:
                    if msg['role'] == 'system':
                        msgs_to_check.append(msg['content'])
                    elif msg['role'] == 'user':
                        msgs_to_check.append(msg['content'])
                        break
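
                # Run each gathered message through the moderation endpoint and
                # collect every flagged category. One flagged message is enough
                # to trigger the policy notice, so stop checking at the first hit.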
                flagged = False
                flagged_categories = []
                # TODO: make this threaded
                for msg in msgs_to_check:
                    flagged, categories = check_moderation_endpoint(msg)
                    flagged_categories.extend(categories)
                    if flagged:
                        break

                if flagged and flagged_categories:
                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
                    self.request.json['messages'].append({'role': 'system', 'content': mod_msg})
                    self.prompt = self.transform_messages_to_prompt()
                    # print(json.dumps(self.request.json['messages'], indent=4))
            except Exception as e:
                print('OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                print(traceback.format_exc())

        # Reconstruct the request JSON with the validated parameters and prompt.
        # The role markers double as stop sequences so the model can't keep
        # writing the conversation as another speaker.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        llm_request = {**self.parameters, 'prompt': self.prompt}

        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        if success:
            return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
        else:
            return backend_response, backend_response_status_code

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(
            ip=self.client_ip,
            token=self.token,
            prompt=self.request_json_body.get('prompt', ''),
            response=backend_response,
            gen_time=None,
            parameters=self.parameters,
            headers=dict(self.request.headers),
            backend_response_code=429,
            request_url=self.request.url,
            is_error=True,
        )
        # Return 200 so the client renders the error text instead of failing outright.
        return build_openai_response(self.prompt, backend_response), 200

    def transform_messages_to_prompt(self):
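        """Flatten the OpenAI-style ``messages`` list into a single alpaca-style
        prompt ending in '### RESPONSE: '. Returns False for a malformed message
        and '' if the transform raises, so callers must handle non-string returns."""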
        try:
            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
            for msg in self.request.json['messages']:
                if not msg.get('content') or not msg.get('role'):
                    return False
                if msg['role'] == 'system':
                    prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
                elif msg['role'] == 'user':
                    prompt += f'### USER: {msg["content"]}\n\n'
                elif msg['role'] == 'assistant':
                    prompt += f'### ASSISTANT: {msg["content"]}\n\n'
                else:
                    return False
        except Exception as e:
            # TODO: use logging
            print('Failed to transform OpenAI to prompt:', f'{e.__class__.__name__}: {e}')
            print(traceback.format_exc())
            return ''
        prompt = prompt.strip(' \n')
        prompt += '\n\n### RESPONSE: '
        return prompt
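
    # For illustration (assuming opts.openai_system_prompt ends with a newline),
    # a list like [{'role': 'user', 'content': 'Hi!'}] becomes roughly:
    #   ### INSTRUCTION: <system prompt>
    #   ### USER: Hi!
    #
    #   ### RESPONSE: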

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        return build_openai_response('', msg), 200


def check_moderation_endpoint(prompt: str) -> Tuple[bool, list]:
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {opts.openai_api_key}",
    }
    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
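    # The moderation response has the shape
    #   {"results": [{"flagged": bool, "categories": {"hate": bool, ...}, ...}]}
    # so collect the name of every category marked true.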
    offending_categories = []
    for k, v in response['results'][0]['categories'].items():
        if v:
            offending_categories.append(k)
    return response['results'][0]['flagged'], offending_categories


def build_openai_response(prompt, response):
    # Separate the user's prompt from the context.
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response.
    y = response.split('\n### ')
    if len(y) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

    # Count tokens with the module-level cl100k_base tokenizer so the usage
    # figures are plain ints in the OpenAI format.
    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })
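
# For illustration, the payload mirrors OpenAI's chat completion format, e.g.:
#   {"id": "chatcmpl-<uuid>", "object": "chat.completion", "created": 1700000000,
#    "model": "<running model>", "choices": [{"index": 0,
#    "message": {"role": "assistant", "content": "..."}, "finish_reason": "stop"}],
#    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}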