diff --git a/README.md b/README.md
index 4e827ca..c95e083 100644
--- a/README.md
+++ b/README.md
@@ -43,6 +43,8 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 ### Use
 
+If you see unexpected errors in the console, make sure `daemon.py` is running; otherwise, the required data will be missing from Redis.
+
 Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
diff --git a/llm_server/cluster/model_choices.py b/llm_server/cluster/model_choices.py
index 4b02b97..31cd8cb 100644
--- a/llm_server/cluster/model_choices.py
+++ b/llm_server/cluster/model_choices.py
@@ -60,7 +60,8 @@ def get_model_choices(regen: bool = False):
         if len(context_size):
             model_choices[model]['context_size'] = min(context_size)
 
-    model_choices = dict(sorted(model_choices.items()))
+    # Sort case-insensitively (by default, Python sorts uppercase letters before lowercase ones).
+    model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
 
     default_backend = get_a_cluster_backend()
     default_backend_dict = {}
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index e69f8fc..ccc8db8 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -2,7 +2,6 @@ from typing import Tuple, Union
 
 import flask
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
 from llm_server.llm import get_token_count
@@ -36,6 +35,8 @@ class LLMBackend:
         """
         If a backend needs to do other checks not related to the prompt or parameters.
         Default is no extra checks preformed.
+        :param request:
+        :param prompt:
         :param parameters:
         :return:
         """
diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
new file mode 100644
index 0000000..5f58da5
--- /dev/null
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -0,0 +1,53 @@
+from flask import jsonify
+
+from llm_server import opts
+
+
+def oai_to_vllm(request_json_body, hashes: bool, mode):
+    if not request_json_body.get('stop'):
+        request_json_body['stop'] = []
+
+    if hashes:
+        request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
+        if opts.openai_force_no_hashes:
+            request_json_body['stop'].append('### ')
+    else:
+        request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
+
+    if request_json_body.get('frequency_penalty', 0) < -2:
+        request_json_body['frequency_penalty'] = -2
+    elif request_json_body.get('frequency_penalty', 0) > 2:
+        request_json_body['frequency_penalty'] = 2
+
+    if mode == 'vllm' and request_json_body.get('top_p') == 0:
+        request_json_body['top_p'] = 0.01
+
+    return request_json_body
+
+
+def format_oai_err(err_msg):
+    return jsonify({
+        "error": {
+            "message": err_msg,
+            "type": "invalid_request_error",
+            "param": None,
+            "code": None
+        }
+    }), 400
+
+
+def validate_oai(parameters):
+    if parameters['temperature'] > 2:
+        return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
+    if parameters['temperature'] < 0:
+        return format_oai_err(f"{parameters['temperature']} is less than the minimum of 0 - 'temperature'")
+
+    if parameters.get('top_p', 1) > 1:
+        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
+    if parameters.get('top_p', 1) < 0:
+        return format_oai_err(f"{parameters['top_p']} is less than the minimum of 0 - 'top_p'")
+
+    if parameters.get('presence_penalty', 1) > 2:
+        return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
+    if parameters.get('presence_penalty', 1) < -2:
+        return format_oai_err(f"{parameters['presence_penalty']} is less than the minimum of -2 - 'presence_penalty'")
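
Roughly how the two new helpers are meant to be used together (a sketch only; the helper name and `parameters` are illustrative and not names from this diff):

    from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai

    def prepare_openai_request(parameters: dict, mode: str):
        # validate_oai() returns a ready-made (jsonify(...), 400) tuple on failure,
        # or None when the parameter ranges are acceptable.
        err = validate_oai(parameters)
        if err:
            return err
        # oai_to_vllm() mutates and returns the body: it fills in stop strings and
        # clamps frequency_penalty (and a zero top_p, in vLLM mode) to accepted values.
        return oai_to_vllm(parameters, hashes=True, mode=mode)
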
diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py
index 62e0ed8..0100c7f 100644
--- a/llm_server/llm/openai/transform.py
+++ b/llm_server/llm/openai/transform.py
@@ -2,73 +2,24 @@ import concurrent.futures
 import re
 import secrets
 import string
-import time
 import traceback
 from typing import Dict, List
 
 import tiktoken
-from flask import jsonify, make_response
 
-import llm_server
 from llm_server import opts
 from llm_server.llm import get_token_count
-from llm_server.custom_redis import redis
 
 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
 ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
 
 
-def build_openai_response(prompt, response, model=None):
-    # Seperate the user's prompt from the context
-    x = prompt.split('### USER:')
-    if len(x) > 1:
-        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
-
-    # Make sure the bot doesn't put any other instructions in its response
-    # y = response.split('\n### ')
-    # if len(y) > 1:
-    #     response = re.sub(r'\n$', '', y[0].strip(' '))
-    response = re.sub(ANTI_RESPONSE_RE, '', response)
-    response = re.sub(ANTI_CONTINUATION_RE, '', response)
-
-    # TODO: async/await
-    prompt_tokens = llm_server.llm.get_token_count(prompt)
-    response_tokens = llm_server.llm.get_token_count(response)
-    running_model = redis.get('running_model', 'ERROR', dtype=str)
-
-    response = make_response(jsonify({
-        "id": f"chatcmpl-{generate_oai_string(30)}",
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": running_model if opts.openai_expose_our_model else model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response,
-            },
-            "logprobs": None,
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": response_tokens,
-            "total_tokens": prompt_tokens + response_tokens
-        }
-    }), 200)
-
-    stats = redis.get('proxy_stats', dtype=dict)
-    if stats:
-        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
-    return response
-
-
 def generate_oai_string(length=24):
     alphabet = string.ascii_letters + string.digits
     return ''.join(secrets.choice(alphabet) for i in range(length))
 
 
-def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
+def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
     tokenizer = tiktoken.get_encoding("cl100k_base")
 
     def get_token_count_tiktoken_thread(msg):
@@ -95,13 +46,13 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
                 break
 
     def get_token_count_thread(msg):
-        return get_token_count(msg["content"])
+        return get_token_count(msg["content"], backend_url)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
         token_counts = list(executor.map(get_token_count_thread, prompt))
 
     total_tokens = sum(token_counts)
-    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
+    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
     if total_tokens + formatting_tokens > context_token_limit:
         # Start over, but this time calculate the token count using the backend
@@ -109,6 +60,40 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
             token_counts = list(executor.map(get_token_count_thread, prompt))
         else:
             break
+    return prompt
+
+
+def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+
+    def get_token_count_tiktoken_thread(msg):
+        return len(tokenizer.encode(msg))
+
+    token_count = get_token_count_tiktoken_thread(prompt)
+
+    # If total tokens exceed the limit, start trimming
+    if token_count > context_token_limit:
+        while True:
+            while token_count > context_token_limit:
+                # Calculate the index to start removing characters from
+                remove_index = len(prompt) // 3
+
+                while remove_index < len(prompt):
+                    prompt = prompt[:remove_index] + prompt[remove_index + 100:]
+                    token_count = get_token_count_tiktoken_thread(prompt)
+                    if token_count <= context_token_limit or remove_index == len(prompt):
+                        break
+
+            def get_token_count_thread(msg):
+                return get_token_count(msg, backend_url)
+
+            token_count = get_token_count_thread(prompt)
+
+            if token_count > context_token_limit:
+                # Start over, but this time calculate the token count using the backend
+                token_count = get_token_count_thread(prompt)
+            else:
+                break
     return prompt
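
A rough sketch of calling the two trimming helpers added above (the backend URL, token limit, and `request_json` are placeholder values, not values from this diff):

    from llm_server.llm.openai.transform import trim_messages_to_fit, trim_string_to_fit

    backend_url = 'http://10.0.0.50:7000'  # placeholder vLLM backend
    context_limit = 4096                   # e.g. model_config['max_position_embeddings']

    # Chat requests: token counts are estimated with tiktoken first, then re-checked
    # against the backend's own tokenizer via get_token_count(..., backend_url).
    messages = trim_messages_to_fit(request_json['messages'], context_limit, backend_url)

    # Plain completion requests: chunks of characters are cut out of the middle of the
    # string until the prompt fits.
    prompt = trim_string_to_fit(request_json['prompt'], context_limit, backend_url)
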
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index a28e59a..abc1cbb 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -1,4 +1,3 @@
-import threading
 from typing import Tuple
 
 from flask import jsonify
@@ -35,9 +34,11 @@ class VLLMBackend(LLMBackend):
                 top_p=parameters.get('top_p', self._default_params['top_p']),
                 top_k=top_k,
                 use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-                stop=parameters.get('stopping_strings', self._default_params['stop']),
+                stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
                 ignore_eos=parameters.get('ban_eos_token', False),
-                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
+                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
+                presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
+                frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
             )
         except ValueError as e:
             return None, str(e).strip('.')
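
A small illustration of the stop-string merging introduced above: either the Ooba-style `stopping_strings` or the OpenAI-style `stop` list is used, and duplicates are collapsed before being handed to vLLM (defaults elided for brevity):

    parameters = {'stop': ['\n### USER', '\n### USER', '\n### ASSISTANT']}
    stop = list(set(parameters.get('stopping_strings') or parameters.get('stop') or []))
    # -> ['\n### USER', '\n### ASSISTANT'] in some order; set() does not preserve ordering.
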
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index b3159a5..0e716e9 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -1,5 +1,4 @@
 import json
-import threading
 import time
 import traceback
 
@@ -10,11 +9,10 @@ from . import openai_bp
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ... import opts
-from ...cluster.backend import get_a_cluster_backend
 from ...database.database import log_prompt
 from ...llm.generator import generator
-from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
-from ...llm.vllm import tokenize
+from ...llm.openai.oai_to_vllm import oai_to_vllm
+from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 
 
 # TODO: add rate-limit headers?
@@ -25,32 +23,46 @@ def openai_chat_completions():
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
-        handler = OpenAIRequestHandler(request, request_json_body)
-        if request_json_body.get('stream'):
+        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
+
+        if handler.cluster_backend_info['mode'] != 'vllm':
+            # TODO: implement other backends
+            raise NotImplementedError
+
+        if not request_json_body.get('stream'):
+            try:
+                return handler.handle_request()
+            except Exception:
+                traceback.print_exc()
+                return 'Internal server error', 500
+        else:
             if not opts.enable_streaming:
                 # TODO: return a proper OAI error message
                 return 'disabled', 401
 
-            if opts.mode != 'vllm':
-                # TODO: implement other backends
-                raise NotImplementedError
+            if opts.openai_silent_trim:
+                handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
 
             response_status_code = 0
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
-                # TODO: simulate OAI here
-                raise Exception('TODO: simulate OAI here')
+                return invalid_response
             else:
-                handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
+                if opts.openai_silent_trim:
+                    oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+                else:
+                    oai_messages = handler.request.json['messages']
+
+                handler.prompt = transform_messages_to_prompt(oai_messages)
+                handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
                 try:
-                    cluster_backend = get_a_cluster_backend()
-                    response = generator(msg_to_backend, cluster_backend)
+                    response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
                     r_url = request.url
                     model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@@ -94,22 +106,20 @@ def openai_chat_completions():
                             end_time = time.time()
                             elapsed_time = end_time - start_time
 
-                            def background_task():
-                                generated_tokens = tokenize(generated_text)
-                                log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
-
-                            # TODO: use async/await instead of threads
-                            thread = threading.Thread(target=background_task)
-                            thread.start()
-                            thread.join()
+                            log_prompt(
+                                handler.client_ip,
+                                handler.token,
+                                handler.prompt,
+                                generated_text,
+                                elapsed_time,
+                                handler.parameters,
+                                r_headers,
+                                response_status_code,
+                                r_url,
+                                handler.backend_url,
+                            )
 
                     return Response(generate(), mimetype='text/event-stream')
                 except:
                     # TODO: simulate OAI here
                     raise Exception
-        else:
-            try:
-                return handler.handle_request()
-            except Exception:
-                traceback.print_exc()
-                return 'Internal server error', 500
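
For orientation, a minimal client-side sketch of consuming the event stream this route returns; the host, path, and model name are placeholders, and the chunk shape is assumed to follow the standard OpenAI `delta` format:

    import json
    import requests

    resp = requests.post('http://localhost:5000/api/openai/v1/chat/completions',
                         json={'model': 'some-model', 'stream': True,
                               'messages': [{'role': 'user', 'content': 'Hello!'}]},
                         stream=True)
    for line in resp.iter_lines():
        if not line or not line.startswith(b'data: '):
            continue
        payload = line[len(b'data: '):]
        if payload == b'[DONE]':
            break
        chunk = json.loads(payload)
        print(chunk['choices'][0]['delta'].get('content', ''), end='')
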
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 8950927..41d1d3b 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -1,15 +1,19 @@
 import time
 import traceback
 
-from flask import jsonify, make_response, request
+import simplejson as json
+from flask import Response, jsonify, request
 
-from . import openai_bp
 from llm_server.custom_redis import redis
+from . import openai_bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
+from ...database.database import log_prompt
 from ...llm import get_token_count
-from ...llm.openai.transform import generate_oai_string
+from ...llm.generator import generator
+from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
+from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
 
 
 # TODO: add rate-limit headers?
@@ -21,40 +25,137 @@ def openai_completions():
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         try:
-            response, status_code = OobaRequestHandler(request).handle_request()
-            if status_code != 200:
-                return status_code
-            output = response.json['results'][0]['text']
+            handler = OobaRequestHandler(incoming_request=request)
 
-            # TODO: async/await
-            prompt_tokens = get_token_count(request_json_body['prompt'])
-            response_tokens = get_token_count(output)
-            running_model = redis.get('running_model', 'ERROR', dtype=str)
+            if handler.cluster_backend_info['mode'] != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
 
-            response = make_response(jsonify({
-                "id": f"cmpl-{generate_oai_string(30)}",
-                "object": "text_completion",
-                "created": int(time.time()),
-                "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
-                "choices": [
-                    {
-                        "text": output,
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": None
+            invalid_oai_err_msg = validate_oai(handler.request_json_body)
+            if invalid_oai_err_msg:
+                return invalid_oai_err_msg
+            handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+
+            # Convert parameters to the selected backend type
+            if opts.openai_silent_trim:
+                handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+            else:
+                # The handle_request() call below will load the prompt so we don't have
+                # to do anything else here.
+                pass
+
+            if not request_json_body.get('stream'):
+                response, status_code = handler.handle_request()
+                if status_code != 200:
+                    return status_code
+                output = response.json['results'][0]['text']
+
+                # TODO: async/await
+                prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
+                response_tokens = get_token_count(output, handler.backend_url)
+                running_model = redis.get('running_model', 'ERROR', dtype=str)
+
+                response = jsonify({
+                    "id": f"cmpl-{generate_oai_string(30)}",
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
+                    "choices": [
+                        {
+                            "text": output,
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": "stop"
+                        }
+                    ],
+                    "usage": {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": response_tokens,
+                        "total_tokens": prompt_tokens + response_tokens
                     }
-                ],
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": response_tokens,
-                    "total_tokens": prompt_tokens + response_tokens
-                }
-            }), 200)
+                })
 
-            stats = redis.get('proxy_stats', dtype=dict)
-            if stats:
-                response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
-            return response
+                stats = redis.get('proxy_stats', dtype=dict)
+                if stats:
+                    response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+                return response, 200
+            else:
+                if not opts.enable_streaming:
+                    # TODO: return a proper OAI error message
+                    return 'disabled', 401
+
+                response_status_code = 0
+                start_time = time.time()
+
+                request_valid, invalid_response = handler.validate_request()
+                if not request_valid:
+                    # TODO: simulate OAI here
+                    raise Exception('TODO: simulate OAI here')
+                else:
+                    handler.prompt = handler.request_json_body['prompt']
+                    msg_to_backend = {
+                        **handler.parameters,
+                        'prompt': handler.prompt,
+                        'stream': True,
+                    }
+                    response = generator(msg_to_backend, handler.backend_url)
+                    r_headers = dict(request.headers)
+                    r_url = request.url
+                    model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
+                    oai_string = generate_oai_string(30)
+
+                    def generate():
+                        generated_text = ''
+                        partial_response = b''
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                            generated_text = generated_text + new
+                                        except IndexError:
+                                            # ????
+                                            continue
+
+                                        data = {
+                                            "id": f"chatcmpl-{oai_string}",
+                                            "object": "text_completion",
+                                            "created": int(time.time()),
+                                            "model": model,
+                                            "choices": [
+                                                {
+                                                    "index": 0,
+                                                    "delta": {
+                                                        "content": new
+                                                    },
+                                                    "finish_reason": None
+                                                }
+                                            ]
+                                        }
+                                        yield f'data: {json.dumps(data)}\n\n'
+
+                        yield 'data: [DONE]\n\n'
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+
+                        log_prompt(
+                            handler.client_ip,
+                            handler.token,
+                            handler.prompt,
+                            generated_text,
+                            elapsed_time,
+                            handler.parameters,
+                            r_headers,
+                            response_status_code,
+                            r_url,
+                            handler.backend_url,
+                        )
+
+                    return Response(generate(), mimetype='text/event-stream')
         except Exception:
             traceback.print_exc()
             return 'Internal Server Error', 500
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 657f084..39931f8 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -3,24 +3,24 @@ import traceback
 import requests
 from flask import jsonify
 
-from . import openai_bp
 from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
+from . import openai_bp
 from ..stats import server_start_time
 from ... import opts
 from ...cluster.backend import get_a_cluster_backend
+from ...cluster.cluster_config import cluster_config
 from ...helpers import jsonify_pretty
-from ...llm.info import get_running_model
+from ...llm.openai.transform import generate_oai_string
 
 
 @openai_bp.route('/models', methods=['GET'])
 @flask_cache.cached(timeout=60, query_string=True)
 def openai_list_models():
-    model, error = get_running_model()
-    if not model:
+    model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',
-            'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         running_model = redis.get('running_model', 'ERROR', dtype=str)
@@ -65,7 +65,14 @@ def fetch_openai_models():
     if opts.openai_api_key:
         try:
             response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
-            return response.json()['data']
+            j = response.json()['data']
+
+            # The "modelperm" string appears to be user-specific, so we'll
+            # randomize it just to be safe.
+            for model in range(len(j)):
+                for p in range(len(j[model]['permission'])):
+                    j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
+            return j
         except:
             traceback.print_exc()
             return []
diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py
index 301e8de..2dafedb 100644
--- a/llm_server/routes/openai/simulated.py
+++ b/llm_server/routes/openai/simulated.py
@@ -17,7 +17,7 @@ def openai_organizations():
                 "id": f"org-{generate_oai_string(24)}",
                 "created": int(server_start_time.timestamp()),
                 "title": "Personal",
-                "name": "user-abcdefghijklmnopqrstuvwx",
+                "name": f"user-{generate_oai_string(24)}",
                 "description": "Personal org for bobjoe@0.0.0.0",
                 "personal": True,
                 "is_default": True,
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index d97ea09..6b9ff98 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -1,14 +1,19 @@
 import json
+import re
+import time
 import traceback
 from typing import Tuple
 from uuid import uuid4
 
 import flask
-from flask import jsonify
+from flask import Response, jsonify, make_response
 
+import llm_server
 from llm_server import opts
+from llm_server.custom_redis import redis
 from llm_server.database.database import is_api_key_moderated
-from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
+from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
+from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.workers.moderator import add_moderation_task, get_results
@@ -22,7 +27,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
 
         if opts.openai_silent_trim:
-            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
+            oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
         else:
             oai_messages = self.request.json['messages']
 
@@ -51,13 +56,8 @@ class OpenAIRequestHandler(RequestHandler):
                 print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                 print(traceback.format_exc())
 
-        # Reconstruct the request JSON with the validated parameters and prompt.
-        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
-        if opts.openai_force_no_hashes:
-            self.parameters['stop'].append('### ')
-
-        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
-            self.request_json_body['top_p'] = 0.01
+        # TODO: support Ooba
+        self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
 
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@@ -65,7 +65,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')
 
         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
+            return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
@@ -75,7 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
         return 'Ratelimited', 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        # TODO: return a simulated OpenAI error message
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",
@@ -84,3 +83,52 @@
                 "code": None
             }
         }), 400
+
+    def build_openai_response(self, prompt, response, model=None):
+        # Separate the user's prompt from the context
+        x = prompt.split('### USER:')
+        if len(x) > 1:
+            prompt = re.sub(r'\n$', '', x[-1].strip(' '))
+
+        # Make sure the bot doesn't put any other instructions in its response
+        response = re.sub(ANTI_RESPONSE_RE, '', response)
+        response = re.sub(ANTI_CONTINUATION_RE, '', response)
+
+        # TODO: async/await
+        prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
+        response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
+        running_model = redis.get('running_model', 'ERROR', dtype=str)
+
+        response = make_response(jsonify({
+            "id": f"chatcmpl-{generate_oai_string(30)}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": running_model if opts.openai_expose_our_model else model,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": response,
+                },
+                "logprobs": None,
+                "finish_reason": "stop"
+            }],
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": response_tokens,
+                "total_tokens": prompt_tokens + response_tokens
+            }
+        }), 200)
+
+        stats = redis.get('proxy_stats', dtype=dict)
+        if stats:
+            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+        return response
+
+    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
+        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        if invalid_oai_err_msg:
+            return False, invalid_oai_err_msg
+        self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        # If the parameters were invalid, let the superclass deal with it.
+        return super().validate_request(prompt, do_log)
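
For reference, a request that fails the new validate_oai() range checks is rejected with HTTP 400 and an OpenAI-style error body shaped like this (values are illustrative):

    {
        "error": {
            "message": "3.5 is greater than the maximum of 2 - 'temperature'",
            "type": "invalid_request_error",
            "param": None,
            "code": None
        }
    }
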
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 0dd862a..ff83e76 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
-from llm_server.routes.queue import RedisPriorityQueue, priority_queue
+from llm_server.routes.queue import priority_queue
 
 DEFAULT_PRIORITY = 9999
 
 
 class RequestHandler:
-    def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
+    def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
         self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
@@ -41,7 +41,7 @@ class RequestHandler:
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
 
         if not self.cluster_backend_info.get('mode'):
-            print(self.backend_url, self.cluster_backend_info)
+            print(selected_model, self.backend_url, self.cluster_backend_info)
 
         self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
         self.parameters = None
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 39db078..1a63db9 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -5,8 +5,6 @@ from flask import jsonify, request
 from . import bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
-from ...cluster.backend import get_a_cluster_backend
-from ...cluster.cluster_config import cluster_config
 
 
 @bp.route('/v1/generate', methods=['POST'])
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 30e0967..500f015 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -71,6 +71,7 @@ def generate_stats(regen: bool = False):
                 'model': backend_info['model'],
                 'mode': backend_info['mode'],
                 'nvidia': backend_info['nvidia'],
+                'priority': backend_info['priority'],
             }
         else:
             output['backend_info'] = {}
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 24d5bc6..c9b9c0d 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -84,7 +84,7 @@ def do_stream(ws, model_name):
             ws.close()
             return auth_failure
 
-    handler = OobaRequestHandler(request, model_name, request_json_body)
+    handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
     generated_text = ''
     input_prompt = request_json_body['prompt']
     response_status_code = 0
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 355b415..df4e3be 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -4,10 +4,8 @@ from flask import jsonify, request
 
 from llm_server.custom_redis import flask_cache
 from . import bp
-from ..auth import requires_auth
 from ... import opts
-from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
-from ...cluster.cluster_config import cluster_config
+from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
 
 
 @bp.route('/v1/model', methods=['GET'])
@@ -39,14 +37,3 @@ def get_model(model_name=None):
 
         flask_cache.set(cache_key, response, timeout=60)
         return response
-
-
-@bp.route('/backends', methods=['GET'])
-@requires_auth
-def get_backend():
-    online, offline = get_backends()
-    result = {}
-    for i in online + offline:
-        info = cluster_config.get_backend(i)
-        result[info['hash']] = info
-    return jsonify(result), 200
diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py
index 5ffd194..e5ff5d3 100644
--- a/llm_server/routes/v1/proxy.py
+++ b/llm_server/routes/v1/proxy.py
@@ -1,6 +1,11 @@
+from flask import jsonify
+
+from llm_server.custom_redis import flask_cache
 from . import bp
 from .generate_stats import generate_stats
-from llm_server.custom_redis import flask_cache
+from ..auth import requires_auth
+from ...cluster.backend import get_backends
+from ...cluster.cluster_config import cluster_config
 from ...helpers import jsonify_pretty
 
 
@@ -8,3 +13,14 @@ from ...helpers import jsonify_pretty
 @flask_cache.cached(timeout=5, query_string=True)
 def get_stats():
     return jsonify_pretty(generate_stats())
+
+
+@bp.route('/backends', methods=['GET'])
+@requires_auth
+def get_backend():
+    online, offline = get_backends()
+    result = {}
+    for i in online + offline:
+        info = cluster_config.get_backend(i)
+        result[info['hash']] = info
+    return jsonify(result), 200
diff --git a/server.py b/server.py
index 71685a4..4191a84 100644
--- a/server.py
+++ b/server.py
@@ -24,11 +24,13 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.sock import init_socketio
 
-# TODO: add a way to cancel VLLM gens. Maybe use websockets?
-# TODO: need to update opts. for workers
-# TODO: add a healthcheck to VLLM
+# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
 
 # Lower priority
+# TODO: support logit_bias on OpenAI and Ooba endpoints.
+# TODO: add a way to cancel VLLM gens. Maybe use websockets?
+# TODO: validate openai_silent_trim works as expected and only when enabled
+# TODO: rewrite config storage. Store in redis so we can reload it.
 # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
 # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
 # TODO: the estiamted wait time lags behind the stats