fix error handling
parent 90bb68115f
commit 957a6cd092
@@ -35,7 +35,6 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
    increment_token_uses(token)

    running_model = redis.get('running_model', str, 'ERROR')

    timestamp = int(time.time())
    cursor = database.cursor()
    try:
@@ -4,6 +4,7 @@ import flask

from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis


class LLMBackend:
@@ -39,5 +40,6 @@ class LLMBackend:
    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = get_token_count(prompt)
        if prompt_len > opts.context_size - 10:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
            model_name = redis.get('running_model', str, 'NO MODEL ERROR')
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
        return True, None
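The validate_prompt change above replaces the generic over-length message with one that also names the currently running model, fetched from redis. A minimal standalone sketch of the same check, with the context size, margin and model name passed in as plain arguments because opts and the redis cache are project-specific (the function name here is illustrative):

from typing import Tuple, Union


def validate_prompt_length(prompt_len: int, context_size: int, model_name: str) -> Tuple[bool, Union[str, None]]:
    """Reject prompts that will not fit into the model's context window."""
    # Keep a small safety margin (10 tokens) for the backend's own bookkeeping.
    if prompt_len > context_size - 10:
        return False, (f'Token indices sequence length is longer than the specified maximum '
                       f'sequence length for this model ({prompt_len} > {context_size}, '
                       f'model: {model_name}). Please lower your context size')
    return True, None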
@@ -3,7 +3,6 @@ This file is used by the worker that processes requests.
"""
import json
import time
import traceback
from uuid import uuid4

import requests
@@ -87,17 +86,22 @@ def handle_blocking_request(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
        return False, None, 'Request to backend timed out'
    except Exception as e:
        traceback.print_exc()
        print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None


def generate(json_data: dict):
    if json_data.get('stream'):
        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
        try:
            return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
        except Exception as e:
            print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
    else:
        return handle_blocking_request(json_data)
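The hunk above layers the error handling around the backend call: a dedicated except requests.exceptions.ReadTimeout branch for timeouts, a catch-all except Exception branch for everything else, and a status-code check, all returning a (success, response, error message) tuple instead of raising; the streaming path in generate() gets the same try/except treatment. A self-contained sketch of that pattern, using a placeholder backend URL and timeout rather than the project's opts values:

import traceback

import requests


def post_to_backend(payload: dict, backend_url: str = 'http://127.0.0.1:7000', timeout: float = 95.0):
    """Return (success, response, error_message) instead of letting exceptions escape."""
    try:
        r = requests.post(f'{backend_url}/generate', json=payload, timeout=timeout)
    except requests.exceptions.ReadTimeout:
        # The backend accepted the connection but did not answer in time.
        return False, None, 'Request to backend timed out'
    except Exception as e:
        # Anything else: DNS failure, connection refused, TLS errors, ...
        traceback.print_exc()
        return False, None, f'Request to backend encountered error ({e.__class__.__name__})'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None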
@@ -1,10 +1,9 @@
import traceback

import requests
import tiktoken

from llm_server import opts


def tokenize(prompt: str) -> int:
    tokenizer = tiktoken.get_encoding("cl100k_base")
    if not prompt:
@@ -14,6 +13,6 @@ def tokenize(prompt: str) -> int:
        r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
        j = r.json()
        return j['length']
    except:
        traceback.print_exc()
    except Exception as e:
        print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
        return len(tokenizer.encode(prompt)) + 10
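In the tokenizer, the bare except: with a traceback dump is narrowed to except Exception with a short log line, while the local tiktoken count (plus a small padding) remains the fallback when the backend's /tokenize endpoint is unreachable. A compact sketch of that remote-count-with-local-fallback idea, with a hypothetical endpoint URL standing in for opts.backend_url:

import requests
import tiktoken


def count_tokens(prompt: str, tokenize_url: str = 'http://127.0.0.1:7000/tokenize') -> int:
    """Ask the backend for an exact token count; fall back to a tiktoken estimate."""
    if not prompt:
        return 0
    try:
        r = requests.post(tokenize_url, json={'input': prompt}, timeout=10)
        return r.json()['length']
    except Exception as e:
        # Backend unreachable or returned junk: estimate locally and pad a little.
        print(f'Failed to tokenize using backend -', f'{e.__class__.__name__}: {e}')
        tokenizer = tiktoken.get_encoding('cl100k_base')
        return len(tokenizer.encode(prompt)) + 10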
@@ -42,4 +42,4 @@ class OobaRequestHandler(RequestHandler):
    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        return jsonify({
            'results': [{'text': msg}]
        }), 400
        }), 200  # return 200 so we don't trigger an error message in the client's ST
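OobaRequestHandler.handle_error now wraps the error text in an ordinary results payload and returns HTTP 200, per the inline comment, so SillyTavern shows the message in the chat instead of raising an error popup. A tiny Flask sketch of that response shape (the route and message are illustrative):

from flask import Flask, jsonify

app = Flask(__name__)


@app.route('/api/v1/generate', methods=['POST'])
def generate():
    msg = 'Backend is unreachable, please try again later.'
    # Deliver the error as normal generation output with a 200 status so the
    # client renders it as chat text rather than triggering an error dialog.
    return jsonify({'results': [{'text': msg}]}), 200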
@@ -81,7 +81,6 @@ class OpenAIRequestHandler(RequestHandler):
        return build_openai_response(self.prompt, backend_response), 429

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        print(msg)
        # return build_openai_response('', msg), 400
        return jsonify({
            "error": {
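The OpenAI handler's hunk is cut off right after the "error" key, so the exact payload is not visible here; the sketch below assumes the conventional OpenAI error envelope (message / type / param / code) purely as an illustration, not as the project's actual response:

from flask import Flask, jsonify

app = Flask(__name__)


def openai_style_error(msg: str, status: int = 400):
    # Assumed field names; the project's real payload is truncated in the hunk above.
    with app.app_context():
        return jsonify({
            'error': {
                'message': msg,
                'type': 'invalid_request_error',
                'param': None,
                'code': None,
            }
        }), status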
@@ -125,8 +125,9 @@ class RequestHandler:
                backend_response = (Response(msg, 400), 400)
            else:
                backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))

            if do_log:
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            return False, backend_response
        return True, (None, 0)
@@ -172,8 +173,8 @@ class RequestHandler:
            if disable_st_error_formatting:
                backend_response = (Response(error_msg, 400), 400)
            else:
                backend_response = format_sillytavern_err(error_msg, 'error')
                log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
                log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return (False, None, None, 0), backend_response

    # ===============================================
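In both RequestHandler hunks the error path now builds its response through handle_error(), which returns a (flask.Response, status) tuple, so the log_prompt() calls switch from logging the tuple object to logging the decoded response body via backend_response[0].data.decode('utf-8'). A small sketch of just that decoding step (log_prompt itself is project code and is not reproduced here):

from flask import Flask, jsonify

app = Flask(__name__)

with app.app_context():
    # handle_error-style return value: (flask.Response, status code)
    backend_response = (jsonify({'results': [{'text': 'Request to backend timed out'}]}), 200)

    # What gets logged: the serialized JSON body, not the tuple's repr().
    logged_text = backend_response[0].data.decode('utf-8')
    print(logged_text)  # e.g. {"results": [{"text": "Request to backend timed out"}]}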
@@ -1,6 +1,8 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build

RUN apt-get update && apt-get install -y git python3-pip python3-venv wget unzip && rm -rf /var/lib/apt/lists/*
RUN apt-get update && \
    apt-get install -y git python3-pip python3-venv wget unzip && \
    rm -rf /var/lib/apt/lists/*
RUN pip3 install --upgrade pip setuptools wheel

RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
@@ -33,7 +35,10 @@ RUN apt-get update && apt-get install -y supervisor && rm -rf /var/lib/apt/lists
RUN useradd -ms /bin/bash apiserver
RUN usermod -s /bin/bash root

RUN apt-get update && apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client
RUN apt-get update && \
    apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client nano tmux && \
    rm -rf /var/lib/apt/lists/*

RUN pip3 install --upgrade pip setuptools wheel
RUN pip3 install glances
@@ -25,10 +25,13 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio

# TODO: have the workers handle streaming too
# TODO: return 200 when returning formatted sillytavern error
# TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: allow setting concurrent gens per-backend
# TODO: use first backend as default backend

# TODO: simulate OpenAI error messages regardless of endpoint
# TODO: allow setting specific simultaneous IPs allowed per token
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead