fix error handling

Cyberes 2023-09-27 14:36:49 -06:00
parent 90bb68115f
commit 957a6cd092
9 changed files with 28 additions and 16 deletions

View File

@@ -35,7 +35,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     increment_token_uses(token)
     running_model = redis.get('running_model', str, 'ERROR')
     timestamp = int(time.time())
     cursor = database.cursor()
     try:

View File

@@ -4,6 +4,7 @@ import flask
 from llm_server import opts
 from llm_server.llm import get_token_count
+from llm_server.routes.cache import redis


 class LLMBackend:
@@ -39,5 +40,6 @@ class LLMBackend:
     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
         prompt_len = get_token_count(prompt)
         if prompt_len > opts.context_size - 10:
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+            model_name = redis.get('running_model', str, 'NO MODEL ERROR')
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
         return True, None
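
Caller-side effect of this change, as a minimal sketch (the backend and handler names below are hypothetical, not part of this commit): the error string returned by validate_prompt() now names the model read from Redis, so a rejected prompt can be traced to a specific backend.

# Hypothetical caller, for illustration only:
valid, err = backend.validate_prompt(prompt)
if not valid:
    # err now reads e.g. "... (4200 > 4096, model: llama-13b). Please lower your context size",
    # or "model: NO MODEL ERROR" when the 'running_model' key is missing from Redis.
    return handler.handle_error(err)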

View File

@@ -3,7 +3,6 @@ This file is used by the worker that processes requests.
 """
 import json
 import time
-import traceback
 from uuid import uuid4

 import requests
@@ -87,17 +86,22 @@ def handle_blocking_request(json_data: dict):
     try:
         r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
     except requests.exceptions.ReadTimeout:
+        print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
         return False, None, 'Request to backend timed out'
     except Exception as e:
-        traceback.print_exc()
+        print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
         return False, None, 'Request to backend encountered error'

     if r.status_code != 200:
+        print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
         return False, r, f'Backend returned {r.status_code}'
     return True, r, None


 def generate(json_data: dict):
     if json_data.get('stream'):
-        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        try:
+            return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        except Exception as e:
+            print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
     else:
         return handle_blocking_request(json_data)
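
The traceback dumps are replaced with one-line messages built from the exception class and its text. A standalone sketch of that pattern (the URL and timeout are placeholders, not values from this project):

import requests

try:
    # Placeholder endpoint; assume nothing is listening here.
    r = requests.post('http://127.0.0.1:7000/generate', json={}, timeout=3)
except requests.exceptions.ReadTimeout:
    print('Failed to reach VLLM inference endpoint - request to backend timed out')
except Exception as e:
    # Condense the failure to "ClassName: message" instead of a full traceback,
    # e.g. "Failed to reach VLLM inference endpoint - ConnectionError: ...".
    print('Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')

Note that in the streaming branch above, the new try/except only prints the error and then falls through, so generate() implicitly returns None on failure; callers presumably have to treat a None response as a failed request.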

View File

@@ -1,10 +1,9 @@
-import traceback
 import requests
 import tiktoken

 from llm_server import opts


 def tokenize(prompt: str) -> int:
     tokenizer = tiktoken.get_encoding("cl100k_base")
     if not prompt:
@@ -14,6 +13,6 @@ def tokenize(prompt: str) -> int:
         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
         j = r.json()
         return j['length']
-    except:
-        traceback.print_exc()
+    except Exception as e:
+        print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
         return len(tokenizer.encode(prompt)) + 10
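
When the backend's /tokenize endpoint can't be reached, the function falls back to a local tiktoken estimate. A standalone sketch of just that fallback (cl100k_base is only an approximation of whatever tokenizer the backend model actually uses, which is presumably why the count is padded by 10):

import tiktoken

def estimate_tokens(prompt: str) -> int:
    # Local approximation used when the VLLM /tokenize endpoint is unreachable.
    tokenizer = tiktoken.get_encoding('cl100k_base')
    return len(tokenizer.encode(prompt)) + 10  # pad the estimate to stay on the safe side

print(estimate_tokens('Hello, world!'))  # prints a small over-estimate of the true token count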

View File

@@ -42,4 +42,4 @@ class OobaRequestHandler(RequestHandler):
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
             'results': [{'text': msg}]
-        }), 400
+        }), 200  # return 200 so we don't trigger an error message in the client's ST
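
For illustration, what an Ooba-style client now receives for a handled error (the message text below is made up): a normal-looking 200 response whose generated text is the error itself, so SillyTavern renders it as chat output instead of raising an API-error popup.

# Illustrative response shape only; the real text comes from format_sillytavern_err().
error_payload = {
    'results': [{'text': 'Request to backend timed out'}]
}
status_code = 200  # was 400 before this commit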

View File

@@ -81,7 +81,6 @@ class OpenAIRequestHandler(RequestHandler):
         return build_openai_response(self.prompt, backend_response), 429

     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        print(msg)
         # return build_openai_response('', msg), 400
         return jsonify({
             "error": {

View File

@@ -125,8 +125,9 @@ class RequestHandler:
                 backend_response = (Response(msg, 400), 400)
             else:
                 backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
             return False, backend_response
         return True, (None, 0)
@@ -172,8 +173,8 @@ class RequestHandler:
         if disable_st_error_formatting:
             backend_response = (Response(error_msg, 400), 400)
         else:
-            backend_response = format_sillytavern_err(error_msg, 'error')
-        log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+            backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
+        log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
         return (False, None, None, 0), backend_response

 # ===============================================
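
The reason for the [0].data.decode('utf-8') change: handle_error() returns a (flask.Response, int) tuple, so passing backend_response straight to log_prompt() would store the tuple's repr rather than the error text. A minimal sketch of the difference (illustrative payload, run under a throwaway Flask app context):

from flask import Flask, jsonify

app = Flask(__name__)
with app.app_context():
    backend_response = (jsonify({'results': [{'text': 'some error'}]}), 200)
    print(backend_response)                          # (<Response ... bytes [200 OK]>, 200) - a repr, useless in the DB
    print(backend_response[0].data.decode('utf-8'))  # the JSON body itself, which is what gets logged now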

View File

@@ -1,6 +1,8 @@
 FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build

-RUN apt-get update && apt-get install -y git python3-pip python3-venv wget unzip && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+    apt-get install -y git python3-pip python3-venv wget unzip && \
+    rm -rf /var/lib/apt/lists/*

 RUN pip3 install --upgrade pip setuptools wheel
 RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
@@ -33,7 +35,10 @@ RUN apt-get update && apt-get install -y supervisor && rm -rf /var/lib/apt/lists
 RUN useradd -ms /bin/bash apiserver
 RUN usermod -s /bin/bash root

-RUN apt-get update && apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client
+RUN apt-get update && \
+    apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client nano tmux && \
+    rm -rf /var/lib/apt/lists/*

 RUN pip3 install --upgrade pip setuptools wheel
 RUN pip3 install glances

View File

@@ -25,10 +25,13 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio

+# TODO: have the workers handle streaming too
+# TODO: return 200 when returning formatted sillytavern error
 # TODO: add some sort of loadbalancer to send requests to a group of backends
 # TODO: allow setting concurrent gens per-backend
 # TODO: use first backend as default backend
+# TODO: simulate OpenAI error messages regardless of endpoint
 # TODO: allow setting specific simoltaneous IPs allowed per token
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead