From 79b1e01b61ebf00c4be53a245766b01ea8a2ad1f Mon Sep 17 00:00:00 2001 From: Cyberes Date: Thu, 14 Sep 2023 14:05:50 -0600 Subject: [PATCH] option to disable streaming, improve timeout on requests to backend, fix error handling. reduce duplicate code, misc other cleanup --- llm_server/config.py | 1 + llm_server/llm/info.py | 4 +- llm_server/llm/llm_backend.py | 4 +- llm_server/llm/oobabooga/generate.py | 6 +- llm_server/llm/oobabooga/ooba_backend.py | 39 ++++---- llm_server/llm/vllm/generate.py | 24 ++--- llm_server/llm/vllm/vllm_backend.py | 76 ++------------- llm_server/netdata.py | 2 +- llm_server/opts.py | 4 + llm_server/routes/cache.py | 29 +++--- llm_server/routes/ooba_request_handler.py | 39 ++------ llm_server/routes/openai/chat_completions.py | 3 +- llm_server/routes/openai_request_handler.py | 52 ++--------- llm_server/routes/request_handler.py | 97 ++++++++++++++++++-- llm_server/routes/v1/generate_stats.py | 2 +- llm_server/routes/v1/generate_stream.py | 5 +- llm_server/threads.py | 13 +-- other/vllm/README.md | 9 ++ server.py | 9 +- templates/home.html | 5 +- 20 files changed, 194 insertions(+), 229 deletions(-) create mode 100644 other/vllm/README.md diff --git a/llm_server/config.py b/llm_server/config.py index 95850eb..76203a9 100644 --- a/llm_server/config.py +++ b/llm_server/config.py @@ -17,6 +17,7 @@ config_default_vars = { 'show_backend_info': True, 'max_new_tokens': 500, 'manual_model_name': False, + 'enable_streaming': True, 'enable_openi_compatible_backend': True, 'expose_openai_system_prompt': True, 'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. 
If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""", diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index b26093d..5a529ba 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -9,14 +9,14 @@ def get_running_model(): if opts.mode == 'oobabooga': try: - backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=10, verify=opts.verify_ssl) + backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e elif opts.mode == 'vllm': try: - backend_response = requests.get(f'{opts.backend_url}/model', timeout=10, verify=opts.verify_ssl) + backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['model'], None except Exception as e: diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index d4dbecb..6dd5874 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -6,7 +6,7 @@ import flask class LLMBackend: default_params: dict - def handle_response(self, success, request: flask.Request, response: flask.Response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): + def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers): raise NotImplementedError def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]: @@ -24,5 +24,5 @@ class LLMBackend: """ raise NotImplementedError - def validate_request(self, parameters: dict) -> (bool, Union[str, None]): + def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]: raise NotImplementedError diff --git a/llm_server/llm/oobabooga/generate.py b/llm_server/llm/oobabooga/generate.py index c4736af..2a533f3 100644 --- a/llm_server/llm/oobabooga/generate.py +++ b/llm_server/llm/oobabooga/generate.py @@ -9,9 +9,11 @@ from llm_server import opts def generate(json_data: dict): try: - r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=120) + r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + except requests.exceptions.ReadTimeout: + return False, None, 'Request to backend timed out' except Exception as e: - return False, None, f'{e.__class__.__name__}: {e}' + return False, None, 'Request to backend encountered error' if r.status_code != 200: return False, r, f'Backend returned {r.status_code}' return True, r, None diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index 28dda5e..48a7336 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -9,25 +9,38 @@ from ...routes.helpers.http import validate_json class OobaboogaBackend(LLMBackend): + default_params = {} + def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): + 
raise NotImplementedError('need to implement default_params') + backend_err = False response_valid_json, response_json_body = validate_json(response) - try: - # Be extra careful when getting attributes from the response object - response_status_code = response.status_code - except: - response_status_code = 0 + if response: + try: + # Be extra careful when getting attributes from the response object + response_status_code = response.status_code + except: + response_status_code = 0 + else: + response_status_code = None # =============================================== + # We encountered an error - if not success or not response: - backend_response = format_sillytavern_err(f'Failed to reach the backend (oobabooga): {error_msg}', 'error') - log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response if response else 0, request.url, is_error=True) + if not success or not response or error_msg: + if not error_msg or error_msg == '': + error_msg = 'Unknown error.' + else: + error_msg = error_msg.strip('.') + '.' + backend_response = format_sillytavern_err(error_msg, 'error') + log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True) return jsonify({ 'code': 500, - 'msg': 'failed to reach backend', + 'msg': error_msg, 'results': [{'text': backend_response}] }), 200 + # =============================================== if response_valid_json: @@ -60,14 +73,6 @@ class OobaboogaBackend(LLMBackend): # No validation required return True, None - # def get_model_info(self) -> Tuple[dict | bool, Exception | None]: - # try: - # backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl) - # r_json = backend_response.json() - # return r_json['result'], None - # except Exception as e: - # return False, e - def get_parameters(self, parameters): del parameters['prompt'] return parameters diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 0a689d8..8e580e4 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -79,25 +79,13 @@ def transform_prompt_to_text(prompt: list): def handle_blocking_request(json_data: dict): try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=120) + r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + except requests.exceptions.ReadTimeout: + return False, None, 'Request to backend timed out' except Exception as e: - return False, None, f'{e.__class__.__name__}: {e}' - - # TODO: check for error here? 
- # response_json = r.json() - # response_json['error'] = False - - # new_response = Response() - # new_response.status_code = r.status_code - # new_response._content = json.dumps(response_json).encode('utf-8') - # new_response.raw = io.BytesIO(new_response._content) - # new_response.headers = r.headers - # new_response.url = r.url - # new_response.reason = r.reason - # new_response.cookies = r.cookies - # new_response.elapsed = r.elapsed - # new_response.request = r.request - + return False, None, 'Request to backend encountered error' # f'{e.__class__.__name__}: {e}' + if r.status_code != 200: + return False, r, f'Backend returned {r.status_code}' return True, r, None diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index 94f39a7..f2fc82d 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -6,80 +6,22 @@ from vllm import SamplingParams from llm_server import opts from llm_server.database import log_prompt from llm_server.llm.llm_backend import LLMBackend -from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.http import validate_json -# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py - -# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69 - class VLLMBackend(LLMBackend): default_params = vars(SamplingParams()) - def handle_response(self, success, request, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers): - response_valid_json, response_json_body = validate_json(response) - backend_err = False - try: - response_status_code = response.status_code - except: - response_status_code = 0 - - if response_valid_json: - if len(response_json_body.get('text', [])): - # Does vllm return the prompt and the response together??? - backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n') - else: - # Failsafe - backend_response = '' - - # TODO: how to detect an error? - # if backend_response == '': - # backend_err = True - # backend_response = format_sillytavern_err( - # f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.', - # f'HTTP CODE {response_status_code}' - # ) - - log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time if not backend_err else None, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err) - return jsonify({'results': [{'text': backend_response}]}), 200 + def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers): + if len(response_json_body.get('text', [])): + # Does vllm return the prompt and the response together??? 
+ backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n') else: - backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error') - log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, request.url, is_error=True) - return jsonify({ - 'code': 500, - 'msg': 'the backend did not return valid JSON', - 'results': [{'text': backend_response}] - }), 200 - - # def validate_params(self, params_dict: dict): - # self.default_params = SamplingParams() - # try: - # sampling_params = SamplingParams( - # temperature=params_dict.get('temperature', self.default_paramstemperature), - # top_p=params_dict.get('top_p', self.default_paramstop_p), - # top_k=params_dict.get('top_k', self.default_paramstop_k), - # use_beam_search=True if params_dict['num_beams'] > 1 else False, - # length_penalty=params_dict.get('length_penalty', self.default_paramslength_penalty), - # early_stopping=params_dict.get('early_stopping', self.default_paramsearly_stopping), - # stop=params_dict.get('stopping_strings', self.default_paramsstop), - # ignore_eos=params_dict.get('ban_eos_token', False), - # max_tokens=params_dict.get('max_new_tokens', self.default_paramsmax_tokens) - # ) - # except ValueError as e: - # print(e) - # return False, e - # return True, None - - # def get_model_info(self) -> Tuple[dict | bool, Exception | None]: - # try: - # backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl) - # r_json = backend_response.json() - # model_path = Path(r_json['data'][0]['root']).name - # r_json['data'][0]['root'] = model_path - # return r_json, None - # except Exception as e: - # return False, e + # Failsafe + backend_response = '' + log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, + response_tokens=response_json_body.get('details', {}).get('generated_tokens')) + return jsonify({'results': [{'text': backend_response}]}), 200 def get_parameters(self, parameters) -> Tuple[dict | None, str | None]: try: diff --git a/llm_server/netdata.py b/llm_server/netdata.py index d3abf32..f37c109 100644 --- a/llm_server/netdata.py +++ b/llm_server/netdata.py @@ -12,7 +12,7 @@ def get_power_states(): while True: url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state" try: - response = requests.get(url, timeout=3) + response = requests.get(url, timeout=10) if response.status_code != 200: break data = json.loads(response.text) diff --git a/llm_server/opts.py b/llm_server/opts.py index ece8555..49bf837 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -27,3 +27,7 @@ llm_middleware_name = '' enable_openi_compatible_backend = True openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. 
If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""" expose_openai_system_prompt = True +enable_streaming = True + +backend_request_timeout = 30 +backend_generate_request_timeout = 120 diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py index c7a7b80..dd962c0 100644 --- a/llm_server/routes/cache.py +++ b/llm_server/routes/cache.py @@ -1,10 +1,10 @@ import json import sys +import redis as redis_pkg from flask_caching import Cache from redis import Redis from redis.typing import FieldT -import redis as redis_pkg cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'}) @@ -20,44 +20,45 @@ class RedisWrapper: self.redis = Redis(**kwargs) self.prefix = prefix try: - self.set('check_connected', 1) + self.set('____', 1) except redis_pkg.exceptions.ConnectionError as e: print('Failed to connect to the Redis server:', e) print('Did you install and start Redis?') sys.exit(1) + def _key(self, key): + return f"{self.prefix}:{key}" + def set(self, key, value): - return self.redis.set(f"{self.prefix}:{key}", value) + return self.redis.set(self._key(key), value) def get(self, key): - return self.redis.get(f"{self.prefix}:{key}") + return self.redis.get(self._key(key)) def incr(self, key, amount=1): - return self.redis.incr(f"{self.prefix}:{key}", amount) + return self.redis.incr(self._key(key), amount) def decr(self, key, amount=1): - return self.redis.decr(f"{self.prefix}:{key}", amount) + return self.redis.decr(self._key(key), amount) def sadd(self, key: str, *values: FieldT): - return self.redis.sadd(f"{self.prefix}:{key}", *values) + return self.redis.sadd(self._key(key), *values) def srem(self, key: str, *values: FieldT): - return self.redis.srem(f"{self.prefix}:{key}", *values) + return self.redis.srem(self._key(key), *values) def sismember(self, key: str, value: str): - return self.redis.sismember(f"{self.prefix}:{key}", value) + return self.redis.sismember(self._key(key), value) def set_dict(self, key, dict_value): - # return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value) - return self.set(f"{self.prefix}:{key}", json.dumps(dict_value)) + return self.set(self._key(key), json.dumps(dict_value)) def get_dict(self, key): - # return self.redis.hgetall(f"{self.prefix}:{key}") - r = self.get(f"{self.prefix}:{key}") + r = self.get(self._key(key)) if not r: return dict() else: - return json.loads(r) + return json.loads(r.decode("utf-8")) def flush(self): flushed = [] diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index 7f830ed..9577480 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -16,45 +16,18 @@ class OobaRequestHandler(RequestHandler): def handle_request(self): if self.used: - raise Exception + raise Exception('Can only use a RequestHandler object once.') - request_valid_json, self.request_json_body = validate_json(self.request) - if not request_valid_json: - return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 - - params_valid, request_valid = self.validate_request() - if not request_valid[0] or not params_valid[0]: - error_messages = [msg for valid, msg in [request_valid, 
params_valid] if not valid and msg] - combined_error_message = ', '.join(error_messages) - err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) - # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types - return jsonify({ - 'code': 400, - 'msg': 'parameter validation error', - 'results': [{'text': err}] - }), 200 + request_valid, invalid_response = self.validate_request() + if not request_valid: + return invalid_response # Reconstruct the request JSON with the validated parameters and prompt. prompt = self.request_json_body.get('prompt', '') llm_request = {**self.parameters, 'prompt': prompt} - if not self.is_client_ratelimited(): - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) - else: - event = None - - if not event: - return self.handle_ratelimited() - - event.wait() - success, response, error_msg = event.data - - end_time = time.time() - elapsed_time = end_time - self.start_time - - self.used = True - return self.backend.handle_response(success, self.request, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) + _, backend_response = self.generate_response(llm_request) + return backend_response def handle_ratelimited(self): backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index 0dfc888..3a56242 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -20,5 +20,4 @@ def openai_chat_completions(): if not request_valid_json or not request_json_body.get('messages'): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: - handler = OpenAIRequestHandler(request) - return handler.handle_request() + return OpenAIRequestHandler(request).handle_request() diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index a26e910..8cb7119 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -1,15 +1,15 @@ import re import time +from typing import Tuple from uuid import uuid4 +import flask import tiktoken from flask import jsonify from llm_server import opts from llm_server.database import log_prompt from llm_server.routes.helpers.client import format_sillytavern_err -from llm_server.routes.helpers.http import validate_json -from llm_server.routes.queue import priority_queue from llm_server.routes.request_handler import RequestHandler tokenizer = tiktoken.get_encoding("cl100k_base") @@ -20,50 +20,22 @@ class OpenAIRequestHandler(RequestHandler): super().__init__(*args, **kwargs) self.prompt = None - def handle_request(self): + def handle_request(self) -> Tuple[flask.Response, int]: if self.used: raise Exception - request_valid_json, self.request_json_body = validate_json(self.request) + request_valid, invalid_response = self.validate_request() + if not request_valid: + return invalid_response + self.prompt = self.transform_messages_to_prompt() - if not request_valid_json or not 
self.prompt: - return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 - - params_valid, request_valid = self.validate_request() - if not request_valid[0] or not params_valid[0]: - error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg] - combined_error_message = ', '.join(error_messages) - err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) - # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types - return jsonify({ - 'code': 400, - 'msg': 'parameter validation error', - 'results': [{'text': err}] - }), 200 - # Reconstruct the request JSON with the validated parameters and prompt. self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']) llm_request = {**self.parameters, 'prompt': self.prompt} - if not self.is_client_ratelimited(): - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) - else: - event = None - - if not event: - return self.handle_ratelimited() - - event.wait() - success, backend_response, error_msg = event.data - - end_time = time.time() - elapsed_time = end_time - self.start_time - - self.used = True - response, response_status_code = self.backend.handle_response(success=success, request=self.request, response=backend_response, error_msg=error_msg, client_ip=self.client_ip, token=self.token, prompt=self.prompt, elapsed_time=elapsed_time, parameters=self.parameters, headers=dict(self.request.headers)) - return build_openai_response(self.prompt, response.json['results'][0]['text']), 200 + _, (backend_response, backend_response_status_code) = self.generate_response(llm_request) + return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code def handle_ratelimited(self): backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. 
Please complete your other requests before sending another.', 'error') @@ -124,9 +96,3 @@ def build_openai_response(prompt, response): "total_tokens": prompt_tokens + response_tokens } }) - -# def transform_prompt_to_text(prompt: list): -# text = '' -# for item in prompt: -# text += item['content'] + '\n' -# return text.strip('\n') diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index c485bab..0550df8 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -1,13 +1,18 @@ import sqlite3 import time -from typing import Union +from typing import Tuple, Union import flask +from flask import Response, jsonify from llm_server import opts +from llm_server.database import log_prompt from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.cache import redis +from llm_server.routes.helpers.client import format_sillytavern_err +from llm_server.routes.helpers.http import validate_json +from llm_server.routes.queue import priority_queue from llm_server.routes.stats import SemaphoreCheckerThread DEFAULT_PRIORITY = 9999 @@ -15,8 +20,8 @@ DEFAULT_PRIORITY = 9999 class RequestHandler: def __init__(self, incoming_request: flask.Request): - self.request_json_body = None self.request = incoming_request + _, self.request_json_body = validate_json(self.request) # routes need to validate it, here we just load it self.start_time = time.time() self.client_ip = self.get_client_ip() self.token = self.request.headers.get('X-Api-Key') @@ -51,27 +56,103 @@ class RequestHandler: self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens') self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body) - def validate_request(self): + def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]: self.load_parameters() params_valid = False request_valid = False - invalid_request_err_msg = None if self.parameters: params_valid = True request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters) - return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg) - def is_client_ratelimited(self): + if not request_valid or not params_valid: + error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg] + combined_error_message = ', '.join(error_messages) + err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error') + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) + # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types + return False, (jsonify({ + 'code': 400, + 'msg': 'parameter validation error', + 'results': [{'text': err}] + }), 200) + return True, (None, 0) + + def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]: + if not self.is_client_ratelimited(): + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority) + else: + event = None + + if not event: + return (False, None, None, 0), self.handle_ratelimited() + + prompt = llm_request['prompt'] + + event.wait() + success, response, error_msg = event.data + + end_time = time.time() + elapsed_time = end_time - 
self.start_time + + if response: + try: + # Be extra careful when getting attributes from the response object + response_status_code = response.status_code + except: + response_status_code = 0 + else: + response_status_code = None + + # =============================================== + + # We encountered an error + if not success or not response or error_msg: + if not error_msg or error_msg == '': + error_msg = 'Unknown error.' + else: + error_msg = error_msg.strip('.') + '.' + backend_response = format_sillytavern_err(error_msg, 'error') + log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + return (False, None, None, 0), (jsonify({ + 'code': 500, + 'msg': error_msg, + 'results': [{'text': backend_response}] + }), 200) + + # =============================================== + + response_valid_json, response_json_body = validate_json(response) + + # The backend didn't send valid JSON + if not response_valid_json: + error_msg = 'The backend did not return valid JSON.' + backend_response = format_sillytavern_err(error_msg, 'error') + log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + return (False, None, None, 0), (jsonify({ + 'code': 500, + 'msg': error_msg, + 'results': [{'text': backend_response}] + }), 200) + + # =============================================== + + self.used = True + return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) + + def is_client_ratelimited(self) -> bool: queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0: return False else: return True - def handle_request(self): + def handle_request(self) -> Tuple[flask.Response, int]: + # Must include this in your child. + # if self.used: + # raise Exception('Can only use a RequestHandler object once.') raise NotImplementedError - def handle_ratelimited(self): + def handle_ratelimited(self) -> Tuple[flask.Response, int]: raise NotImplementedError diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index 6149a70..783997a 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -82,7 +82,7 @@ def generate_stats(): 'online': online, 'endpoints': { 'blocking': f'https://{opts.base_client_api}', - 'streaming': f'wss://{opts.base_client_api}/v1/stream', + 'streaming': f'wss://{opts.base_client_api}/v1/stream' if opts.enable_streaming else None, }, 'queue': { 'processing': active_gen_workers, diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 92c9bec..11940c7 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -15,7 +15,10 @@ from ...stream import sock @sock.route('/api/v1/stream') # TODO: use blueprint route??? 
def stream(ws): - return 'disabled', 401 + if not opts.enable_streaming: + # TODO: return a formatted ST error message + return 'disabled', 401 + # start_time = time.time() # if request.headers.get('cf-connecting-ip'): # client_ip = request.headers.get('cf-connecting-ip') diff --git a/llm_server/threads.py b/llm_server/threads.py index baf3da5..1529165 100644 --- a/llm_server/threads.py +++ b/llm_server/threads.py @@ -25,14 +25,6 @@ class MainBackgroundThread(Thread): def run(self): while True: if opts.mode == 'oobabooga': - # try: - # r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl) - # opts.running_model = r.json()['result'] - # redis.set('backend_online', 1) - # except Exception as e: - # redis.set('backend_online', 0) - # # TODO: handle error - # print(e) model, err = get_running_model() if err: print(err) @@ -52,7 +44,7 @@ class MainBackgroundThread(Thread): raise Exception # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 - # was entered into the column. The new code enters null instead but we need to be backwards compatible for now + # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True) or 0 redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) @@ -65,7 +57,6 @@ class MainBackgroundThread(Thread): # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model) # print(f'Weighted: {average_output_tokens}, overall: {overall}') - # Avoid division by zero - estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 + estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero redis.set('estimated_avg_tps', estimated_avg_tps) time.sleep(60) diff --git a/other/vllm/README.md b/other/vllm/README.md new file mode 100644 index 0000000..eb59ae8 --- /dev/null +++ b/other/vllm/README.md @@ -0,0 +1,9 @@ +### Nginx + +1. Make sure your proxies all have a long timeout: +``` +proxy_read_timeout 300; +proxy_connect_timeout 300; +proxy_send_timeout 300; +``` +The LLM middleware has a request timeout of 120 so this longer timeout is to avoid any issues. 
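For context, here is a minimal sketch of where these directives might sit in a site config. The server name, listen port, and upstream address are assumptions for illustration only, not values taken from this repository:
```
# Hypothetical reverse-proxy block; server_name and the upstream address
# are placeholders, not part of this repository's configuration.
server {
    listen 443 ssl;
    server_name llm.example.com;

    location / {
        proxy_pass http://127.0.0.1:5000;
        proxy_set_header Host $host;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

        # Keep these comfortably above the middleware's 120-second generate timeout.
        proxy_read_timeout 300;
        proxy_connect_timeout 300;
        proxy_send_timeout 300;
    }
}
```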
\ No newline at end of file diff --git a/server.py b/server.py index 5446db6..7a600fb 100644 --- a/server.py +++ b/server.py @@ -73,6 +73,7 @@ opts.llm_middleware_name = config['llm_middleware_name'] opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend'] opts.openai_system_prompt = config['openai_system_prompt'] opts.expose_openai_system_prompt = config['expose_openai_system_prompt'] +opts.enable_streaming = config['enable_streaming'] opts.verify_ssl = config['verify_ssl'] if not opts.verify_ssl: @@ -107,8 +108,6 @@ app = Flask(__name__) cache.init_app(app) cache.clear() # clear redis cache init_socketio(app) -# with app.app_context(): -# current_app.tokenizer = tiktoken.get_encoding("cl100k_base") app.register_blueprint(bp, url_prefix='/api/v1/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') @@ -118,7 +117,8 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') @app.route('/') @app.route('/api') -@cache.cached(timeout=10, query_string=True) +@app.route('/api/openai') +@cache.cached(timeout=60) def home(): if not opts.base_client_api: opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}' @@ -165,7 +165,8 @@ def home(): stats_json=json.dumps(stats, indent=4, ensure_ascii=False), extra_info=mode_info, openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', - expose_openai_system_prompt=opts.expose_openai_system_prompt + expose_openai_system_prompt=opts.expose_openai_system_prompt, + enable_streaming=opts.enable_streaming, ) diff --git a/templates/home.html b/templates/home.html index 3acb10e..ffebc11 100644 --- a/templates/home.html +++ b/templates/home.html @@ -77,7 +77,7 @@

         Estimated Wait Time: {{ estimated_wait }}
         Client API URL: {{ client_api }}
-        Streaming API URL: {{ ws_client_api }}
+        Streaming API URL: {{ ws_client_api if enable_streaming else 'Disabled' }}
         OpenAI-Compatible API URL: {{ openai_client_api }}
{% if info_html|length > 1 %}
@@ -93,8 +93,7 @@
  1. Set your API type to {{ mode_name }}
  2. Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
-  3. Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
+  {% if enable_streaming %}
+  3. Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
+  {% endif %}
   4. If you have a token, check the Mancer AI checkbox and enter your token in the Mancer API key textbox.