local-llm-server/llm_server/routes/openai_request_handler.py


import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import Response, jsonify, make_response

from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results

class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None
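
    # Overall flow of handle_request(): optionally trim the chat history to fit the model's
    # context window, flatten the OpenAI-style message list into a single text prompt,
    # convert and validate the OpenAI parameters for the backend, optionally run the
    # moderation scan, then forward the prompt to the backend and re-wrap the result as an
    # OpenAI chat.completion response.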

    def handle_request(self) -> Tuple[flask.Response, int]:
        assert not self.used

        if self.offline:
            msg = return_invalid_model_err(self.selected_model)
            print(msg)
            return self.handle_error(msg)

        if opts.openai_silent_trim:
            oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages)
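
        # The message list has now been flattened into a single text prompt. Next, map the
        # OpenAI-style request body onto the backend's parameter names; stop_hashes appears
        # to be enabled for non-instruct models so that "###" role markers act as stop
        # sequences (an assumption based on the call below and the regexes used later).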
        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        if not self.prompt:
            # TODO: format this as an OpenAI error message
            return Response('Invalid prompt'), 400

        # TODO: support the Ooba backend
        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        invalid_oai_err_msg = validate_oai(self.request_json_body)
        if invalid_oai_err_msg:
            return invalid_oai_err_msg
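
        # Moderation (when enabled for this API key) goes through the moderation worker:
        # the most recent opts.openai_moderation_scan_last_n messages are queued under a
        # shared tag and the flagged categories are collected before generation continues.
        # If anything is flagged, a system message is appended and the prompt is rebuilt.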
        if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
            try:
                # Gather the last message from the user and all preceding system messages
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()

                tag = uuid4()
                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)

                flagged_categories = get_results(tag, num_to_check)
                if len(flagged_categories):
                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].insert(len(self.request.json['messages']), {'role': 'system', 'content': mod_msg})
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                print('OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                traceback.print_exc()
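
        # generate_response() returns a tuple of tuples; only the success flag, the backend
        # response, and its status code are used here.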
        llm_request = {**self.parameters, 'prompt': self.prompt}
        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        model = self.request_json_body.get('model')

        if success:
            return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
        else:
            return backend_response, backend_response_status_code

    def handle_ratelimited(self, do_log: bool = True):
        model_choices, default_model = get_model_choices()
        default_model_info = model_choices[default_model]
        w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
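        # Advertise at least a 2 second retry window when the backend reports no estimated wait.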
        response = jsonify({
            "error": {
                "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
                "type": "rate_limit_exceeded",
                "param": None,
                "code": None
            }
        })
        response.headers['x-ratelimit-limit-requests'] = '2'
        response.headers['x-ratelimit-remaining-requests'] = '0'
        response.headers['x-ratelimit-reset-requests'] = f"{w}s"

        if do_log:
            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)

        return response, 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        print(error_msg)
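        # The real error is only logged server-side; clients always receive a generic
        # invalid_request_error.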
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
                "type": "invalid_request_error",
                "param": None,
                "code": None
            }
        }), 400

    def build_openai_response(self, prompt, response, model=None):
        # Separate the user's prompt from the context
        x = prompt.split('### USER:')
        if len(x) > 1:
            prompt = re.sub(r'\n$', '', x[-1].strip(' '))

        # Make sure the bot doesn't put any other instructions in its response
        response = re.sub(ANTI_RESPONSE_RE, '', response)
        response = re.sub(ANTI_CONTINUATION_RE, '', response)

        prompt_tokens = get_token_count(prompt, self.backend_url)
        response_tokens = get_token_count(response, self.backend_url)
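        # Token counts are recomputed against the selected backend so the usage block below
        # reflects what that backend actually tokenizes.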
        running_model = redis.get('running_model', 'ERROR', dtype=str)
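
        # Build a response in OpenAI's chat.completion format. The locally running model
        # name is exposed only when opts.openai_expose_our_model is set; otherwise the
        # client's requested model name is echoed back.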
        response = make_response(jsonify({
            "id": f"chatcmpl-{generate_oai_string(30)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": running_model if opts.openai_expose_our_model else model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response,
                },
                "logprobs": None,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
        }), 200)
        return response

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        self.parameters, parameters_invalid_msg = self.get_parameters()
        if not self.parameters:
            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
            return False, (Response('Invalid request, check your parameters and try again.'), 400)

        invalid_oai_err_msg = validate_oai(self.parameters)
        if invalid_oai_err_msg:
            return False, invalid_oai_err_msg

        # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        # Everything else is left to the superclass to validate.
        return super().validate_request(prompt, do_log)
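

# A minimal usage sketch (hypothetical: the actual blueprint, route path, and constructor
# arguments live elsewhere in llm_server.routes and are only assumed here for illustration):
#
#   @openai_bp.route('/v1/chat/completions', methods=['POST'])
#   def openai_chat_completions():
#       handler = OpenAIRequestHandler(flask.request)  # constructor args are an assumption
#       response, status_code = handler.handle_request()
#       return response, status_code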