local-llm-server/llm_server/routes/openai_request_handler.py


import json
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import jsonify

from llm_server import opts
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        assert not self.used

        # Optionally trim the conversation so the rendered prompt fits the backend's context window.
        if opts.openai_silent_trim:
            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages)

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        if opts.openai_api_key and is_api_key_moderated(self.token):
            try:
                # Scan the most recent messages (newest first, up to opts.openai_moderation_scan_last_n)
                # against the moderation backend.
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                tag = uuid4()
                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)
                flagged_categories = get_results(tag, num_to_check)
                if len(flagged_categories):
                    # Append a system message naming the flagged categories, then rebuild the prompt.
                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].insert(len(self.request.json['messages']), {'role': 'system', 'content': mod_msg})
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                print(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
                print(traceback.format_exc())
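
        # For reference, a minimal hedged sketch of what a single inline moderation check could look
        # like if the background workers wrap OpenAI's /v1/moderations endpoint (an assumption about
        # the worker internals), including how flagged category names are pulled from the response:
        #
        #     import requests
        #     r = requests.post('https://api.openai.com/v1/moderations',
        #                       headers={'Authorization': f'Bearer {opts.openai_api_key}'},
        #                       json={'input': message_text})  # message_text: hypothetical variable
        #     result = r.json()['results'][0]
        #     flagged = [name for name, hit in result['categories'].items() if hit]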

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        if opts.openai_force_no_hashes:
            self.parameters['stop'].append('### ')

        # vLLM rejects top_p == 0, so bump it to the smallest usable value.
        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
            self.request_json_body['top_p'] = 0.01

        llm_request = {**self.parameters, 'prompt': self.prompt}
        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        model = self.request_json_body.get('model')

        if success:
            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
        else:
            return backend_response, backend_response_status_code

    def handle_ratelimited(self):
        # TODO: return a simulated OpenAI error message
        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
        return 'Ratelimited', 429
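
    # A hedged sketch of the simulated OpenAI-style error the TODO above describes, mirroring the
    # shape returned by handle_error() below; the "type" string is an assumption, not the project's
    # confirmed behaviour:
    #
    #     return jsonify({
    #         "error": {
    #             "message": f"Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} "
    #                        "simultaneous requests at a time. Please complete your other requests before sending another.",
    #             "type": "rate_limit_error",
    #             "param": None,
    #             "code": None
    #         }
    #     }), 429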

    def handle_error(self, error_msg: str) -> Tuple[flask.Response, int]:
        # TODO: return a simulated OpenAI error message
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
                "type": "invalid_request_error",
                "param": None,
                "code": None
            }
        }), 400
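

# Usage sketch: the real route wiring lives elsewhere in the project, but a minimal Flask endpoint
# built on this handler might look roughly like the following. The route path and the constructor
# argument are assumptions for illustration only, not the project's actual interface.
#
#     from flask import Flask, request
#
#     app = Flask(__name__)
#
#     @app.route('/api/openai/v1/chat/completions', methods=['POST'])
#     def openai_chat_completions():
#         handler = OpenAIRequestHandler(request)  # assumed constructor signature
#         response, status_code = handler.handle_request()
#         return response, status_code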