From 9740df07c7185ff4f96346fbca044226160d3ea7 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Tue, 12 Sep 2023 16:40:09 -0600 Subject: [PATCH] add openai-compatible backend --- llm_server/config.py | 4 +- llm_server/llm/llm_backend.py | 5 + llm_server/llm/oobabooga/ooba_backend.py | 2 +- llm_server/llm/vllm/generate.py | 1 - llm_server/llm/vllm/vllm_backend.py | 52 ++++---- llm_server/opts.py | 3 + llm_server/routes/helpers/http.py | 25 ++-- llm_server/routes/ooba_request_handler.py | 64 +++++++++ llm_server/routes/openai/__init__.py | 32 +++++ llm_server/routes/openai/chat_completions.py | 24 ++++ llm_server/routes/openai/models.py | 57 ++++++++ llm_server/routes/openai_request_handler.py | 132 +++++++++++++++++++ llm_server/routes/request_handler.py | 122 ++++++----------- llm_server/routes/stats.py | 1 + llm_server/routes/v1/generate.py | 4 +- llm_server/routes/v1/generate_stats.py | 2 +- llm_server/routes/v1/generate_stream.py | 1 - llm_server/threads.py | 1 + server.py | 8 +- templates/home.html | 1 + 20 files changed, 412 insertions(+), 129 deletions(-) create mode 100644 llm_server/routes/ooba_request_handler.py create mode 100644 llm_server/routes/openai/__init__.py create mode 100644 llm_server/routes/openai/chat_completions.py create mode 100644 llm_server/routes/openai/models.py create mode 100644 llm_server/routes/openai_request_handler.py diff --git a/llm_server/config.py b/llm_server/config.py index ba64952..01c979a 100644 --- a/llm_server/config.py +++ b/llm_server/config.py @@ -16,7 +16,9 @@ config_default_vars = { 'simultaneous_requests_per_ip': 3, 'show_backend_info': True, 'max_new_tokens': 500, - 'manual_model_name': False + 'manual_model_name': False, + 'enable_openi_compatible_backend': True, + 'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. 
Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n""", } config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index cf27e67..fae0aba 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -2,6 +2,8 @@ from typing import Tuple, Union class LLMBackend: + default_params: dict + def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): raise NotImplementedError @@ -19,3 +21,6 @@ class LLMBackend: :return: """ raise NotImplementedError + + def validate_request(self, parameters: dict) -> (bool, Union[str, None]): + raise NotImplementedError diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index ee3a7d6..4d45c36 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -8,7 +8,7 @@ from ...routes.helpers.client import format_sillytavern_err from ...routes.helpers.http import validate_json -class OobaboogaLLMBackend(LLMBackend): +class OobaboogaBackend(LLMBackend): def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): backend_err = False response_valid_json, response_json_body = validate_json(response) diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 2e1267c..b740679 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -36,7 +36,6 @@ def transform_to_text(json_request, api_response): data = json.loads(line[5:].strip()) except json.decoder.JSONDecodeError: break - print(data) if 'choices' in data: for choice in data['choices']: if 'delta' in choice and 'content' in choice['delta']: diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index 5a57de6..6828bf4 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -1,8 +1,9 @@ -from typing import Tuple +from typing import Tuple, Union from flask import jsonify from vllm import SamplingParams +from llm_server import opts from llm_server.database import log_prompt from llm_server.llm.llm_backend import LLMBackend from llm_server.routes.helpers.client import format_sillytavern_err @@ -14,7 +15,9 @@ from llm_server.routes.helpers.http import validate_json # TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69 class VLLMBackend(LLMBackend): - def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): + default_params = vars(SamplingParams()) + + def handle_response(self, success, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers): response_valid_json, response_json_body = validate_json(response) backend_err = False try: @@ -50,18 +53,18 @@ class VLLMBackend(LLMBackend): }), 200 # def validate_params(self, params_dict: dict): - # default_params = SamplingParams() + # self.default_params = SamplingParams() # try: # sampling_params = SamplingParams( - # temperature=params_dict.get('temperature', default_params.temperature), - # top_p=params_dict.get('top_p', default_params.top_p), - # top_k=params_dict.get('top_k', 
default_params.top_k), + # temperature=params_dict.get('temperature', self.default_paramstemperature), + # top_p=params_dict.get('top_p', self.default_paramstop_p), + # top_k=params_dict.get('top_k', self.default_paramstop_k), # use_beam_search=True if params_dict['num_beams'] > 1 else False, - # length_penalty=params_dict.get('length_penalty', default_params.length_penalty), - # early_stopping=params_dict.get('early_stopping', default_params.early_stopping), - # stop=params_dict.get('stopping_strings', default_params.stop), + # length_penalty=params_dict.get('length_penalty', self.default_paramslength_penalty), + # early_stopping=params_dict.get('early_stopping', self.default_paramsearly_stopping), + # stop=params_dict.get('stopping_strings', self.default_paramsstop), # ignore_eos=params_dict.get('ban_eos_token', False), - # max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens) + # max_tokens=params_dict.get('max_new_tokens', self.default_paramsmax_tokens) # ) # except ValueError as e: # print(e) @@ -79,30 +82,21 @@ class VLLMBackend(LLMBackend): # return False, e def get_parameters(self, parameters) -> Tuple[dict | None, str | None]: - default_params = SamplingParams() try: sampling_params = SamplingParams( - temperature=parameters.get('temperature', default_params.temperature), - top_p=parameters.get('top_p', default_params.top_p), - top_k=parameters.get('top_k', default_params.top_k), - use_beam_search=True if parameters['num_beams'] > 1 else False, - stop=parameters.get('stopping_strings', default_params.stop), + temperature=parameters.get('temperature', self.default_params['temperature']), + top_p=parameters.get('top_p', self.default_params['top_p']), + top_k=parameters.get('top_k', self.default_params['top_k']), + use_beam_search=True if parameters.get('num_beams', 0) > 1 else False, + stop=parameters.get('stopping_strings', self.default_params['stop']), ignore_eos=parameters.get('ban_eos_token', False), - max_tokens=parameters.get('max_new_tokens', default_params.max_tokens) + max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens']) ) except ValueError as e: return None, str(e).strip('.') return vars(sampling_params), None -# def transform_sampling_params(params: SamplingParams): -# return { -# 'temperature': params['temperature'], -# 'top_p': params['top_p'], -# 'top_k': params['top_k'], -# 'use_beam_search' = True if parameters['num_beams'] > 1 else False, -# length_penalty = parameters.get('length_penalty', default_params.length_penalty), -# early_stopping = parameters.get('early_stopping', default_params.early_stopping), -# stop = parameters.get('stopping_strings', default_params.stop), -# ignore_eos = parameters.get('ban_eos_token', False), -# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens) -# } + def validate_request(self, parameters) -> (bool, Union[str, None]): + if parameters.get('max_new_tokens', 0) > opts.max_new_tokens: + return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}' + return True, None diff --git a/llm_server/opts.py b/llm_server/opts.py index 8708b83..2674af6 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -23,3 +23,6 @@ netdata_root = None simultaneous_requests_per_ip = 3 show_backend_info = True manual_model_name = None +llm_middleware_name = '' +enable_openi_compatible_backend = True +openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. 
You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n""" diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py index 8784160..8cc7a02 100644 --- a/llm_server/routes/helpers/http.py +++ b/llm_server/routes/helpers/http.py @@ -1,11 +1,11 @@ import json +from functools import wraps from typing import Union -from flask import make_response -from requests import Response - +import flask +import requests +from flask import make_response, Request from flask import request, jsonify -from functools import wraps from llm_server import opts from llm_server.database import is_valid_api_key @@ -39,15 +39,18 @@ def require_api_key(): return jsonify({'code': 401, 'message': 'API key required'}), 401 -def validate_json(data: Union[str, Response]): - if isinstance(data, Response): - try: +def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response]): + try: + if isinstance(data, (Request, flask.Response)): + data = data.json + return True, data + elif isinstance(data, requests.models.Response): data = data.json() return True, data - except Exception as e: - return False, None + except Exception as e: + return False, e try: - j = json.loads(data) + j = json.loads(str(data)) return True, j except Exception as e: - return False, None + return False, e diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py new file mode 100644 index 0000000..0d30894 --- /dev/null +++ b/llm_server/routes/ooba_request_handler.py @@ -0,0 +1,64 @@ +import time + +from flask import jsonify + +from llm_server import opts +from llm_server.database import log_prompt +from llm_server.routes.helpers.client import format_sillytavern_err +from llm_server.routes.helpers.http import validate_json +from llm_server.routes.queue import priority_queue +from llm_server.routes.request_handler import RequestHandler + + +class OobaRequestHandler(RequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def handle_request(self): + if self.used: + raise Exception + + request_valid_json, self.request_json_body = validate_json(self.request) + if not request_valid_json: + return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 + + params_valid, request_valid = self.validate_request() + if not request_valid[0] or not params_valid[0]: + error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg] + combined_error_message = ', '.join(error_messages) + err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True) + # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and 
OpenAI response types + return jsonify({ + 'code': 400, + 'msg': 'parameter validation error', + 'results': [{'text': err}] + }), 200 + + # Reconstruct the request JSON with the validated parameters and prompt. + prompt = self.request_json_body.get('prompt', '') + llm_request = {**self.parameters, 'prompt': prompt} + + if not self.is_client_ratelimited(): + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) + else: + event = None + + if not event: + return self.handle_ratelimited() + + event.wait() + success, response, error_msg = event.data + + end_time = time.time() + elapsed_time = end_time - self.start_time + + self.used = True + return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) + + def handle_ratelimited(self): + backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True) + return jsonify({ + 'results': [{'text': backend_response}] + }), 200 diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py new file mode 100644 index 0000000..3b21ecf --- /dev/null +++ b/llm_server/routes/openai/__init__.py @@ -0,0 +1,32 @@ +from flask import Blueprint, request + +from ..helpers.client import format_sillytavern_err +from ..helpers.http import require_api_key +from ..openai_request_handler import build_openai_response +from ..server_error import handle_server_error +from ... import opts + +openai_bp = Blueprint('openai/v1/', __name__) + + +@openai_bp.before_request +def before_request(): + if not opts.http_host: + opts.http_host = request.headers.get("Host") + if not opts.enable_openi_compatible_backend: + return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401 + if not opts.base_client_api: + opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}' + if request.endpoint != 'v1.get_stats': + response = require_api_key() + if response is not None: + return response + + +@openai_bp.errorhandler(500) +def handle_error(e): + return handle_server_error(e) + + +from .models import openai_list_models +from .chat_completions import openai_chat_completions diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py new file mode 100644 index 0000000..0dfc888 --- /dev/null +++ b/llm_server/routes/openai/chat_completions.py @@ -0,0 +1,24 @@ +from flask import jsonify, request + +from . 
import openai_bp +from ..helpers.http import validate_json +from ..openai_request_handler import OpenAIRequestHandler + + +class FakeFlaskRequest(): + def __init__(self, *args, **kwargs): + self.data = kwargs.get('data') + self.headers = kwargs.get('headers') + self.json = kwargs.get('json') + self.remote_addr = kwargs.get('remote_addr') + + +@openai_bp.route('/chat/completions', methods=['POST']) +def openai_chat_completions(): + # TODO: make this work with oobabooga + request_valid_json, request_json_body = validate_json(request) + if not request_valid_json or not request_json_body.get('messages'): + return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 + else: + handler = OpenAIRequestHandler(request) + return handler.handle_request() diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py new file mode 100644 index 0000000..9d6a223 --- /dev/null +++ b/llm_server/routes/openai/models.py @@ -0,0 +1,57 @@ +from flask import jsonify, request + +from . import openai_bp +from ..cache import cache, redis +from ..stats import server_start_time +from ... import opts +from ...llm.info import get_running_model + + +@openai_bp.route('/models', methods=['GET']) +def openai_list_models(): + cache_key = 'openai_model_cache::' + request.url + cached_response = cache.get(cache_key) + + if cached_response: + return cached_response + + model, error = get_running_model() + if not model: + response = jsonify({ + 'code': 502, + 'msg': 'failed to reach backend', + 'type': error.__class__.__name__ + }), 500 # return 500 so Cloudflare doesn't intercept us + else: + response = jsonify({ + "object": "list", + "data": [ + { + "id": opts.running_model, + "object": "model", + "created": int(server_start_time.timestamp()), + "owned_by": opts.llm_middleware_name, + "permission": [ + { + "id": opts.running_model, + "object": "model_permission", + "created": int(server_start_time.timestamp()), + "allow_create_engine": False, + "allow_sampling": False, + "allow_logprobs": False, + "allow_search_indices": False, + "allow_view": True, + "allow_fine_tuning": False, + "organization": "*", + "group": None, + "is_blocking": False + } + ], + "root": None, + "parent": None + } + ] + }), 200 + cache.set(cache_key, response, timeout=60) + + return response diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py new file mode 100644 index 0000000..399b2cd --- /dev/null +++ b/llm_server/routes/openai_request_handler.py @@ -0,0 +1,132 @@ +import re +import time +from uuid import uuid4 + +import tiktoken +from flask import jsonify + +from llm_server import opts +from llm_server.database import log_prompt +from llm_server.routes.helpers.client import format_sillytavern_err +from llm_server.routes.helpers.http import validate_json +from llm_server.routes.queue import priority_queue +from llm_server.routes.request_handler import RequestHandler + +tokenizer = tiktoken.get_encoding("cl100k_base") + + +class OpenAIRequestHandler(RequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt = None + + def handle_request(self): + if self.used: + raise Exception + + request_valid_json, self.request_json_body = validate_json(self.request) + self.prompt = self.transform_messages_to_prompt() + + if not request_valid_json or not self.prompt: + return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 + + params_valid, request_valid = self.validate_request() + if not request_valid[0] or not params_valid[0]: + error_messages = 
[msg for valid, msg in [request_valid, params_valid] if not valid and msg] + combined_error_message = ', '.join(error_messages) + err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True) + # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types + return jsonify({ + 'code': 400, + 'msg': 'parameter validation error', + 'results': [{'text': err}] + }), 200 + + # Reconstruct the request JSON with the validated parameters and prompt. + self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']) + llm_request = {**self.parameters, 'prompt': self.prompt} + + if not self.is_client_ratelimited(): + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) + else: + event = None + + if not event: + return self.handle_ratelimited() + + event.wait() + success, backend_response, error_msg = event.data + + end_time = time.time() + elapsed_time = end_time - self.start_time + + self.used = True + response, response_status_code = self.backend.handle_response(success, backend_response, error_msg, self.client_ip, self.token, self.prompt, elapsed_time, self.parameters, dict(self.request.headers)) + return build_openai_response(self.prompt, response.json['results'][0]['text']), 200 + + def handle_ratelimited(self): + backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True) + return build_openai_response(self.prompt, backend_response), 200 + + def transform_messages_to_prompt(self): + try: + prompt = f'### INSTRUCTION: {opts.openai_system_prompt}' + for msg in self.request.json['messages']: + if not msg.get('content') or not msg.get('role'): + return False + if msg['role'] == 'system': + prompt += f'### INSTRUCTION: {msg["content"]}\n\n' + elif msg['role'] == 'user': + prompt += f'### USER: {msg["content"]}\n\n' + elif msg['role'] == 'assistant': + prompt += f'### ASSISTANT: {msg["content"]}\n\n' + else: + return False + except: + return False + + prompt = prompt.strip(' ').strip('\n').strip('\n\n') # TODO: this is really lazy + prompt += '\n\n### RESPONSE: ' + return prompt + + +def build_openai_response(prompt, response): + # Separate the user's prompt from the context + x = prompt.split('### USER:') + if len(x) > 1: + prompt = re.sub(r'\n$', '', x[-1].strip(' ')) + + # Make sure the bot doesn't put any other instructions in its response + y = response.split('\n### ') + if len(y) > 1: + response = re.sub(r'\n$', '', y[0].strip(' ')) + + prompt_tokens = len(tokenizer.encode(prompt)) + response_tokens = len(tokenizer.encode(response)) + return jsonify({ + "id": f"chatcmpl-{uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": opts.running_model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": response, + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": response_tokens, + "total_tokens": prompt_tokens + response_tokens + } + 
}) + +# def transform_prompt_to_text(prompt: list): +# text = '' +# for item in prompt: +# text += item['content'] + '\n' +# return text.strip('\n') diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index dd29ee8..743ec38 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -2,36 +2,16 @@ import sqlite3 import time from typing import Union -from flask import jsonify - from llm_server import opts -from llm_server.database import log_prompt -from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend -from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend +from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.cache import redis -from llm_server.routes.helpers.client import format_sillytavern_err -from llm_server.routes.helpers.http import validate_json -from llm_server.routes.queue import priority_queue from llm_server.routes.stats import SemaphoreCheckerThread DEFAULT_PRIORITY = 9999 -def delete_dict_key(d: dict, k: Union[str, list]): - if isinstance(k, str): - if k in d.keys(): - del d[k] - elif isinstance(k, list): - for item in k: - if item in d.keys(): - del d[item] - else: - raise ValueError - return d - - -class OobaRequestHandler: +class RequestHandler: def __init__(self, incoming_request): self.request_json_body = None self.request = incoming_request @@ -39,14 +19,10 @@ class OobaRequestHandler: self.client_ip = self.get_client_ip() self.token = self.request.headers.get('X-Api-Key') self.priority = self.get_priority() - self.backend = self.get_backend() + self.backend = get_backend() self.parameters = self.parameters_invalid_msg = None - - def validate_request(self) -> (bool, Union[str, None]): - # TODO: move this to LLMBackend - if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens: - return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}' - return True, None + self.used = False + SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time() def get_client_ip(self): if self.request.headers.get('cf-connecting-ip'): @@ -67,69 +43,53 @@ class OobaRequestHandler: return result[0] return DEFAULT_PRIORITY - def get_backend(self): - if opts.mode == 'oobabooga': - return OobaboogaLLMBackend() - elif opts.mode == 'vllm': - return VLLMBackend() - else: - raise Exception - - def get_parameters(self): + def load_parameters(self): + # Handle OpenAI + if self.request_json_body.get('max_tokens'): + self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens') self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body) - def handle_request(self): - SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time() - - request_valid_json, self.request_json_body = validate_json(self.request.data) - if not request_valid_json: - return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 - - self.get_parameters() + def validate_request(self): + self.load_parameters() params_valid = False request_valid = False invalid_request_err_msg = None if self.parameters: params_valid = True - request_valid, invalid_request_err_msg = self.validate_request() - - if not request_valid or not params_valid: - error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not 
valid and msg] - combined_error_message = ', '.join(error_messages) - err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True) - # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types - return jsonify({ - 'code': 400, - 'msg': 'parameter validation error', - 'results': [{'text': err}] - }), 200 - - # Reconstruct the request JSON with the validated parameters and prompt. - prompt = self.request_json_body.get('prompt', '') - llm_request = {**self.parameters, 'prompt': prompt} + request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters) + return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg) + def is_client_ratelimited(self): queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0: - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) + return False else: - # Client was rate limited - event = None + return True - if not event: - return self.handle_ratelimited() - - event.wait() - success, response, error_msg = event.data - - end_time = time.time() - elapsed_time = end_time - self.start_time - - return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) + def handle_request(self): + raise NotImplementedError def handle_ratelimited(self): - backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True) - return jsonify({ - 'results': [{'text': backend_response}] - }), 200 + raise NotImplementedError + + +def get_backend(): + if opts.mode == 'oobabooga': + return OobaboogaBackend() + elif opts.mode == 'vllm': + return VLLMBackend() + else: + raise Exception + + +def delete_dict_key(d: dict, k: Union[str, list]): + if isinstance(k, str): + if k in d.keys(): + del d[k] + elif isinstance(k, list): + for item in k: + if item in d.keys(): + del d[item] + else: + raise ValueError + return d diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index 1454a5f..8d11838 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -6,6 +6,7 @@ from llm_server.routes.cache import redis # proompters_1_min = 0 # concurrent_semaphore = Semaphore(concurrent_gens) + server_start_time = datetime.now() # TODO: have a background thread put the averages in a variable so we don't end up with massive arrays diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index 4b78bf8..69cf2d0 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -2,13 +2,13 @@ from flask import jsonify, request from . import bp from ..helpers.http import validate_json -from ..request_handler import OobaRequestHandler +from ..ooba_request_handler import OobaRequestHandler from ... 
import opts @bp.route('/generate', methods=['POST']) def generate(): - request_valid_json, request_json_body = validate_json(request.data) + request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index 2959baf..16002e4 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -62,10 +62,10 @@ def generate_stats(): 'power_state': power_state, # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu'))) } - else: netdata_stats = {} + output = { 'stats': { 'proompters': { diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 71de731..92c9bec 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -8,7 +8,6 @@ from ..helpers.client import format_sillytavern_err from ... import opts from ...database import log_prompt from ...helpers import indefinite_article -from ...llm.hf_textgen.generate import prepare_json from ...stream import sock diff --git a/llm_server/threads.py b/llm_server/threads.py index a5a6162..202e3cd 100644 --- a/llm_server/threads.py +++ b/llm_server/threads.py @@ -1,4 +1,5 @@ import time +from datetime import datetime from threading import Thread import requests diff --git a/server.py b/server.py index f6b9f42..a455ef0 100644 --- a/server.py +++ b/server.py @@ -6,6 +6,7 @@ from threading import Thread from flask import Flask, jsonify, render_template, request +from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error try: @@ -68,6 +69,9 @@ opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip'] opts.show_backend_info = config['show_backend_info'] opts.max_new_tokens = config['max_new_tokens'] opts.manual_model_name = config['manual_model_name'] +opts.llm_middleware_name = config['llm_middleware_name'] +opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend'] +opts.openai_system_prompt = config['openai_system_prompt'] opts.verify_ssl = config['verify_ssl'] if not opts.verify_ssl: @@ -105,6 +109,7 @@ init_socketio(app) # with app.app_context(): # current_app.tokenizer = tiktoken.get_encoding("cl100k_base") app.register_blueprint(bp, url_prefix='/api/v1/') +app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') # print(app.url_map) @@ -145,7 +150,7 @@ def home(): mode_info = vllm_info return render_template('home.html', - llm_middleware_name=config['llm_middleware_name'], + llm_middleware_name=opts.llm_middleware_name, analytics_tracking_code=analytics_tracking_code, info_html=info_html, current_model=opts.manual_model_name if opts.manual_model_name else running_model, @@ -158,6 +163,7 @@ def home(): context_size=opts.context_size, stats_json=json.dumps(stats, indent=4, ensure_ascii=False), extra_info=mode_info, + openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', ) diff --git a/templates/home.html b/templates/home.html index 530e4a4..02d2f61 100644 --- a/templates/home.html +++ b/templates/home.html @@ -76,6 +76,7 @@

 Current Model: {{ current_model }}
 Client API URL: {{ client_api }}
 Streaming API URL: {{ ws_client_api }}
+OpenAI-Compatible API URL: {{ openai_client_api }}
 Estimated Wait Time: {{ estimated_wait }}
 {{ info_html|safe }}
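
Usage sketch (not part of the patch): one way a client might exercise the new OpenAI-compatible endpoints. The paths follow the blueprint registered at /api/openai/v1/ in server.py, and the X-Api-Key header mirrors the token header read by RequestHandler; the base URL and key below are hypothetical placeholders.

import requests

BASE_URL = 'https://proxy.example.com/api/openai/v1'  # hypothetical host; path mirrors the url_prefix in server.py
HEADERS = {'X-Api-Key': 'your-api-key'}  # hypothetical key; X-Api-Key is the token header used by the request handlers

# List the models the proxy reports (handled by openai_list_models).
models = requests.get(f'{BASE_URL}/models', headers=HEADERS, timeout=30)
print(models.json())

# Request a chat completion (handled by openai_chat_completions); a 'messages' list is required.
payload = {
    'messages': [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Hello!'},
    ],
    'max_tokens': 100,  # mapped to max_new_tokens by RequestHandler.load_parameters()
}
resp = requests.post(f'{BASE_URL}/chat/completions', json=payload, headers=HEADERS, timeout=120)
print(resp.json()['choices'][0]['message']['content'])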