option to disable streaming, improve timeouts on requests to the backend, fix error handling, reduce duplicate code, misc other cleanup
This commit is contained in:
parent
e79b206e1a
commit
79b1e01b61
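For orientation, here is a hedged sketch of how the two new backend timeout options introduced by this commit are meant to be used. The option names (`backend_request_timeout`, `backend_generate_request_timeout`) and the endpoints come from the diff below; the wrapper functions themselves are illustrative, not code taken from the repository.

```
import requests

from llm_server import opts  # defaults added in this commit: 30s for status calls, 120s for generation


def fetch_running_model():
    # Hypothetical helper: short timeout for lightweight status requests (opts.backend_request_timeout).
    r = requests.get(f'{opts.backend_url}/api/v1/model',
                     timeout=opts.backend_request_timeout,
                     verify=opts.verify_ssl)
    return r.json()


def generate(json_data: dict):
    # Longer timeout for generation requests (opts.backend_generate_request_timeout).
    r = requests.post(f'{opts.backend_url}/api/v1/generate',
                      json=json_data,
                      timeout=opts.backend_generate_request_timeout,
                      verify=opts.verify_ssl)
    return r.json()
```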
@@ -17,6 +1,7 @@ config_default_vars = {
    'show_backend_info': True,
    'max_new_tokens': 500,
    'manual_model_name': False,
    'enable_streaming': True,
    'enable_openi_compatible_backend': True,
    'expose_openai_system_prompt': True,
    'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
@@ -9,14 +9,14 @@ def get_running_model():

    if opts.mode == 'oobabooga':
        try:
            backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=10, verify=opts.verify_ssl)
            backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
            r_json = backend_response.json()
            return r_json['result'], None
        except Exception as e:
            return False, e
    elif opts.mode == 'vllm':
        try:
            backend_response = requests.get(f'{opts.backend_url}/model', timeout=10, verify=opts.verify_ssl)
            backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
            r_json = backend_response.json()
            return r_json['model'], None
        except Exception as e:
@@ -6,7 +6,7 @@ import flask
class LLMBackend:
    default_params: dict

    def handle_response(self, success, request: flask.Request, response: flask.Response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
    def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError

    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
@@ -24,5 +24,5 @@ class LLMBackend:
        """
        raise NotImplementedError

    def validate_request(self, parameters: dict) -> (bool, Union[str, None]):
    def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
        raise NotImplementedError
@@ -9,9 +9,11 @@ from llm_server import opts

def generate(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=120)
        r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
@@ -9,25 +9,38 @@ from ...routes.helpers.http import validate_json


class OobaboogaBackend(LLMBackend):
    default_params = {}

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError('need to implement default_params')

        backend_err = False
        response_valid_json, response_json_body = validate_json(response)
        try:
            # Be extra careful when getting attributes from the response object
            response_status_code = response.status_code
        except:
            response_status_code = 0
        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response:
            backend_response = format_sillytavern_err(f'Failed to reach the backend (oobabooga): {error_msg}', 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response if response else 0, request.url, is_error=True)
        if not success or not response or error_msg:
            if not error_msg or error_msg == '':
                error_msg = 'Unknown error.'
            else:
                error_msg = error_msg.strip('.') + '.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'failed to reach backend',
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 200

        # ===============================================

        if response_valid_json:
@@ -60,14 +73,6 @@ class OobaboogaBackend(LLMBackend):
        # No validation required
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    # try:
    # backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
    # r_json = backend_response.json()
    # return r_json['result'], None
    # except Exception as e:
    # return False, e

    def get_parameters(self, parameters):
        del parameters['prompt']
        return parameters
@@ -79,25 +79,13 @@ def transform_prompt_to_text(prompt: list):

def handle_blocking_request(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=120)
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'

        # TODO: check for error here?
        # response_json = r.json()
        # response_json['error'] = False

        # new_response = Response()
        # new_response.status_code = r.status_code
        # new_response._content = json.dumps(response_json).encode('utf-8')
        # new_response.raw = io.BytesIO(new_response._content)
        # new_response.headers = r.headers
        # new_response.url = r.url
        # new_response.reason = r.reason
        # new_response.cookies = r.cookies
        # new_response.elapsed = r.elapsed
        # new_response.request = r.request

        return False, None, 'Request to backend encountered error' # f'{e.__class__.__name__}: {e}'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
@@ -6,80 +6,22 @@ from vllm import SamplingParams
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json


# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69

class VLLMBackend(LLMBackend):
    default_params = vars(SamplingParams())

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except:
            response_status_code = 0

        if response_valid_json:
            if len(response_json_body.get('text', [])):
                # Does vllm return the prompt and the response together???
                backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
            else:
                # Failsafe
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            # backend_err = True
            # backend_response = format_sillytavern_err(
            # f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            # f'HTTP CODE {response_status_code}'
            # )

            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time if not backend_err else None, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
    def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        if len(response_json_body.get('text', [])):
            # Does vllm return the prompt and the response together???
            backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
        else:
            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    # def validate_params(self, params_dict: dict):
    # self.default_params = SamplingParams()
    # try:
    # sampling_params = SamplingParams(
    # temperature=params_dict.get('temperature', self.default_paramstemperature),
    # top_p=params_dict.get('top_p', self.default_paramstop_p),
    # top_k=params_dict.get('top_k', self.default_paramstop_k),
    # use_beam_search=True if params_dict['num_beams'] > 1 else False,
    # length_penalty=params_dict.get('length_penalty', self.default_paramslength_penalty),
    # early_stopping=params_dict.get('early_stopping', self.default_paramsearly_stopping),
    # stop=params_dict.get('stopping_strings', self.default_paramsstop),
    # ignore_eos=params_dict.get('ban_eos_token', False),
    # max_tokens=params_dict.get('max_new_tokens', self.default_paramsmax_tokens)
    # )
    # except ValueError as e:
    # print(e)
    # return False, e
    # return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    # try:
    # backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    # r_json = backend_response.json()
    # model_path = Path(r_json['data'][0]['root']).name
    # r_json['data'][0]['root'] = model_path
    # return r_json, None
    # except Exception as e:
    # return False, e
            # Failsafe
            backend_response = ''
        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
        return jsonify({'results': [{'text': backend_response}]}), 200

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        try:
@@ -12,7 +12,7 @@ def get_power_states():
    while True:
        url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
        try:
            response = requests.get(url, timeout=3)
            response = requests.get(url, timeout=10)
            if response.status_code != 200:
                break
            data = json.loads(response.text)
@@ -27,3 +27,7 @@ llm_middleware_name = ''
enable_openi_compatible_backend = True
openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
expose_openai_system_prompt = True
enable_streaming = True

backend_request_timeout = 30
backend_generate_request_timeout = 120
@@ -1,10 +1,10 @@
import json
import sys

import redis as redis_pkg
from flask_caching import Cache
from redis import Redis
from redis.typing import FieldT
import redis as redis_pkg

cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
@@ -20,44 +20,45 @@ class RedisWrapper:
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            self.set('check_connected', 1)
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
            print('Failed to connect to the Redis server:', e)
            print('Did you install and start Redis?')
            sys.exit(1)

    def _key(self, key):
        return f"{self.prefix}:{key}"

    def set(self, key, value):
        return self.redis.set(f"{self.prefix}:{key}", value)
        return self.redis.set(self._key(key), value)

    def get(self, key):
        return self.redis.get(f"{self.prefix}:{key}")
        return self.redis.get(self._key(key))

    def incr(self, key, amount=1):
        return self.redis.incr(f"{self.prefix}:{key}", amount)
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        return self.redis.decr(f"{self.prefix}:{key}", amount)
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        return self.redis.sadd(f"{self.prefix}:{key}", *values)
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        return self.redis.srem(f"{self.prefix}:{key}", *values)
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        return self.redis.sismember(f"{self.prefix}:{key}", value)
        return self.redis.sismember(self._key(key), value)

    def set_dict(self, key, dict_value):
        # return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value)
        return self.set(f"{self.prefix}:{key}", json.dumps(dict_value))
        return self.set(self._key(key), json.dumps(dict_value))

    def get_dict(self, key):
        # return self.redis.hgetall(f"{self.prefix}:{key}")
        r = self.get(f"{self.prefix}:{key}")
        r = self.get(self._key(key))
        if not r:
            return dict()
        else:
            return json.loads(r)
            return json.loads(r.decode("utf-8"))

    def flush(self):
        flushed = []
@@ -16,45 +16,18 @@ class OobaRequestHandler(RequestHandler):

    def handle_request(self):
        if self.used:
            raise Exception
            raise Exception('Can only use a RequestHandler object once.')

        request_valid_json, self.request_json_body = validate_json(self.request)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200
        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None

        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time

        self.used = True
        return self.backend.handle_response(success, self.request, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
        _, backend_response = self.generate_response(llm_request)
        return backend_response

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@@ -20,5 +20,4 @@ def openai_chat_completions():
    if not request_valid_json or not request_json_body.get('messages'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
    else:
        handler = OpenAIRequestHandler(request)
        return handler.handle_request()
        return OpenAIRequestHandler(request).handle_request()
@@ -1,15 +1,15 @@
import re
import time
from typing import Tuple
from uuid import uuid4

import flask
import tiktoken
from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler

tokenizer = tiktoken.get_encoding("cl100k_base")
@@ -20,50 +20,22 @@ class OpenAIRequestHandler(RequestHandler):
        super().__init__(*args, **kwargs)
        self.prompt = None

    def handle_request(self):
    def handle_request(self) -> Tuple[flask.Response, int]:
        if self.used:
            raise Exception

        request_valid_json, self.request_json_body = validate_json(self.request)
        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        self.prompt = self.transform_messages_to_prompt()

        if not request_valid_json or not self.prompt:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        llm_request = {**self.parameters, 'prompt': self.prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None

        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, backend_response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time

        self.used = True
        response, response_status_code = self.backend.handle_response(success=success, request=self.request, response=backend_response, error_msg=error_msg, client_ip=self.client_ip, token=self.token, prompt=self.prompt, elapsed_time=elapsed_time, parameters=self.parameters, headers=dict(self.request.headers))
        return build_openai_response(self.prompt, response.json['results'][0]['text']), 200
        _, (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@@ -124,9 +96,3 @@ def build_openai_response(prompt, response):
            "total_tokens": prompt_tokens + response_tokens
        }
    })

# def transform_prompt_to_text(prompt: list):
# text = ''
# for item in prompt:
# text += item['content'] + '\n'
# return text.strip('\n')
@@ -1,13 +1,18 @@
import sqlite3
import time
from typing import Union
from typing import Tuple, Union

import flask
from flask import Response, jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread

DEFAULT_PRIORITY = 9999
@@ -15,8 +20,8 @@ DEFAULT_PRIORITY = 9999

class RequestHandler:
    def __init__(self, incoming_request: flask.Request):
        self.request_json_body = None
        self.request = incoming_request
        _, self.request_json_body = validate_json(self.request) # routes need to validate it, here we just load it
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
@@ -51,27 +56,103 @@ class RequestHandler:
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def validate_request(self):
    def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
        self.load_parameters()
        params_valid = False
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            params_valid = True
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
        return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)

    def is_client_ratelimited(self):
        if not request_valid or not params_valid:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return False, (jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200)
        return True, (None, 0)

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None

        if not event:
            return (False, None, None, 0), self.handle_ratelimited()

        prompt = llm_request['prompt']

        event.wait()
        success, response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time

        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            if not error_msg or error_msg == '':
                error_msg = 'Unknown error.'
            else:
                error_msg = error_msg.strip('.') + '.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return (False, None, None, 0), (jsonify({
                'code': 500,
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 200)

        # ===============================================

        response_valid_json, response_json_body = validate_json(response)

        # The backend didn't send valid JSON
        if not response_valid_json:
            error_msg = 'The backend did not return valid JSON.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return (False, None, None, 0), (jsonify({
                'code': 500,
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 200)

        # ===============================================

        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            return False
        else:
            return True

    def handle_request(self):
    def handle_request(self) -> Tuple[flask.Response, int]:
        # Must include this in your child.
        # if self.used:
        # raise Exception('Can only use a RequestHandler object once.')
        raise NotImplementedError

    def handle_ratelimited(self):
    def handle_ratelimited(self) -> Tuple[flask.Response, int]:
        raise NotImplementedError
@@ -82,7 +82,7 @@ def generate_stats():
        'online': online,
        'endpoints': {
            'blocking': f'https://{opts.base_client_api}',
            'streaming': f'wss://{opts.base_client_api}/v1/stream',
            'streaming': f'wss://{opts.base_client_api}/v1/stream' if opts.enable_streaming else None,
        },
        'queue': {
            'processing': active_gen_workers,
@@ -15,7 +15,10 @@ from ...stream import sock

@sock.route('/api/v1/stream') # TODO: use blueprint route???
def stream(ws):
    return 'disabled', 401
    if not opts.enable_streaming:
        # TODO: return a formatted ST error message
        return 'disabled', 401

    # start_time = time.time()
    # if request.headers.get('cf-connecting-ip'):
    # client_ip = request.headers.get('cf-connecting-ip')
@@ -25,14 +25,6 @@ class MainBackgroundThread(Thread):
    def run(self):
        while True:
            if opts.mode == 'oobabooga':
                # try:
                # r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
                # opts.running_model = r.json()['result']
                # redis.set('backend_online', 1)
                # except Exception as e:
                # redis.set('backend_online', 0)
                # # TODO: handle error
                # print(e)
                model, err = get_running_model()
                if err:
                    print(err)
@@ -52,7 +44,7 @@ class MainBackgroundThread(Thread):
                raise Exception

            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True) or 0
            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
@@ -65,7 +57,6 @@ class MainBackgroundThread(Thread):
            # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
            # print(f'Weighted: {average_output_tokens}, overall: {overall}')

            # Avoid division by zero
            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(60)
@@ -0,0 +1,9 @@
### Nginx

1. Make sure your proxies all have a long timeout:
```
proxy_read_timeout 300;
proxy_connect_timeout 300;
proxy_send_timeout 300;
```
The LLM middleware has a 120-second timeout on requests to the backend, so this longer proxy timeout avoids requests being cut off prematurely.
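For illustration only, a minimal server block using these directives might look like the sketch below. The upstream address (`127.0.0.1:5000`) and server name are assumptions for the example, not values taken from this repository.

```
server {
    listen 80;
    server_name llm.example.com;  # assumed hostname

    location / {
        # Keep these comfortably above the middleware's 120-second backend timeout.
        proxy_read_timeout 300;
        proxy_connect_timeout 300;
        proxy_send_timeout 300;

        proxy_pass http://127.0.0.1:5000;  # assumed address of the middleware
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
    }
}
```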
@@ -73,6 +73,7 @@ opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']

opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
@@ -107,8 +108,6 @@ app = Flask(__name__)
cache.init_app(app)
cache.clear() # clear redis cache
init_socketio(app)
# with app.app_context():
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
@@ -118,7 +117,8 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')

@app.route('/')
@app.route('/api')
@cache.cached(timeout=10, query_string=True)
@app.route('/api/openai')
@cache.cached(timeout=60)
def home():
    if not opts.base_client_api:
        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
@@ -165,7 +165,8 @@ def home():
        stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
        extra_info=mode_info,
        openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
        expose_openai_system_prompt=opts.expose_openai_system_prompt
        expose_openai_system_prompt=opts.expose_openai_system_prompt,
        enable_streaming=opts.enable_streaming,
    )
@@ -77,7 +77,7 @@
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
<br>
<p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
<p><strong>OpenAI-Compatible API URL:</strong> {{ openai_client_api }}</p>
{% if info_html|length > 1 %}
<br>
@@ -93,8 +93,7 @@
<ol>
    <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
    <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
    <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
    </li>
    {% if enable_streaming %}<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>{% endif %}
    <li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
        API key</kbd> textbox.
    </li>