diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index ae1e70d..f307c2e 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -35,7 +35,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
         increment_token_uses(token)
 
     running_model = redis.get('running_model', str, 'ERROR')
-    timestamp = int(time.time())
     cursor = database.cursor()
     try:
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 0d8ec27..1c11c17 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -4,6 +4,7 @@ import flask
 
 from llm_server import opts
 from llm_server.llm import get_token_count
+from llm_server.routes.cache import redis
 
 
 class LLMBackend:
@@ -39,5 +40,6 @@ class LLMBackend:
     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
         prompt_len = get_token_count(prompt)
         if prompt_len > opts.context_size - 10:
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+            model_name = redis.get('running_model', str, 'NO MODEL ERROR')
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
         return True, None
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 86a27ac..1549f2e 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -3,7 +3,6 @@ This file is used by the worker that processes requests.
 """
 import json
 import time
-import traceback
 from uuid import uuid4
 
 import requests
@@ -87,17 +86,22 @@ def handle_blocking_request(json_data: dict):
     try:
         r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
     except requests.exceptions.ReadTimeout:
+        print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
         return False, None, 'Request to backend timed out'
     except Exception as e:
-        traceback.print_exc()
+        print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
         return False, None, 'Request to backend encountered error'
     if r.status_code != 200:
+        print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
         return False, r, f'Backend returned {r.status_code}'
     return True, r, None
 
 
 def generate(json_data: dict):
     if json_data.get('stream'):
-        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        try:
+            return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        except Exception as e:
+            print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
     else:
         return handle_blocking_request(json_data)
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index 5a8d09a..6bb9343 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -1,10 +1,9 @@
-import traceback
-
 import requests
 import tiktoken
 
 from llm_server import opts
 
+
 def tokenize(prompt: str) -> int:
     tokenizer = tiktoken.get_encoding("cl100k_base")
     if not prompt:
@@ -14,6 +13,6 @@ def tokenize(prompt: str) -> int:
         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
         j = r.json()
         return j['length']
-    except:
-        traceback.print_exc()
+    except Exception as e:
+        print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
         return len(tokenizer.encode(prompt)) + 10
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 11ab1e8..1f186f0 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -42,4 +42,4 @@ class OobaRequestHandler(RequestHandler):
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
             'results': [{'text': msg}]
-        }), 400
+        }), 200  # return 200 so we don't trigger an error message in the client's ST
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 89832fc..1aedf26 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -81,7 +81,6 @@ class OpenAIRequestHandler(RequestHandler):
         return build_openai_response(self.prompt, backend_response), 429
 
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        print(msg)
         # return build_openai_response('', msg), 400
         return jsonify({
             "error": {
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 16e3522..90ca620 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -125,8 +125,9 @@ class RequestHandler:
                 backend_response = (Response(msg, 400), 400)
             else:
                 backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
+
             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
             return False, backend_response
         return True, (None, 0)
@@ -172,8 +173,8 @@
             if disable_st_error_formatting:
                 backend_response = (Response(error_msg, 400), 400)
             else:
-                backend_response = format_sillytavern_err(error_msg, 'error')
-            log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
+            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
             return (False, None, None, 0), backend_response
 
     # ===============================================
diff --git a/other/vllm/Docker/Dockerfile b/other/vllm/Docker/Dockerfile
index d5be5d7..afd23f7 100644
--- a/other/vllm/Docker/Dockerfile
+++ b/other/vllm/Docker/Dockerfile
@@ -1,6 +1,8 @@
 FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build
 
-RUN apt-get update && apt-get install -y git python3-pip python3-venv wget unzip && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+    apt-get install -y git python3-pip python3-venv wget unzip && \
+    rm -rf /var/lib/apt/lists/*
 RUN pip3 install --upgrade pip setuptools wheel
 
 RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
@@ -33,7 +35,10 @@ RUN apt-get update && apt-get install -y supervisor && rm -rf /var/lib/apt/lists
 RUN useradd -ms /bin/bash apiserver
 RUN usermod -s /bin/bash root
 
-RUN apt-get update && apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client
+RUN apt-get update && \
+    apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client nano tmux && \
+    rm -rf /var/lib/apt/lists/*
+
 RUN pip3 install --upgrade pip setuptools wheel
 RUN pip3 install glances
diff --git a/server.py b/server.py
index 3f56cbb..39a2eaa 100644
--- a/server.py
+++ b/server.py
@@ -25,10 +25,13 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
 
+# TODO: have the workers handle streaming too
+# TODO: return 200 when returning formatted sillytavern error
 # TODO: add some sort of loadbalancer to send requests to a group of backends
 # TODO: allow setting concurrent gens per-backend
 # TODO: use first backend as default backend
+# TODO: simulate OpenAI error messages regardless of endpoint
 # TODO: allow setting specific simoltaneous IPs allowed per token
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead
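Note on the request_handler.py hunks above: handle_error() returns a (flask.Response, int) tuple, so log_prompt() is now passed the decoded response body rather than the tuple itself. A minimal sketch of that behaviour, assuming a bare Flask app (the app object and the message text here are illustrative only, not taken from the repo):

    from flask import Flask, jsonify

    app = Flask(__name__)

    def handle_error(msg: str):
        # Mirrors OobaRequestHandler.handle_error(): JSON body, 200 status.
        return jsonify({'results': [{'text': msg}]}), 200

    with app.app_context():
        backend_response = handle_error('Request to backend timed out')
        # backend_response[0] is a flask.Response; .data holds the serialized JSON bytes.
        logged_text = backend_response[0].data.decode('utf-8')
        print(logged_text)  # e.g. {"results": [{"text": "Request to backend timed out"}]}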