add openai-compatible backend

Cyberes 2023-09-12 16:40:09 -06:00
parent 1d9f40765e
commit 9740df07c7
20 changed files with 412 additions and 129 deletions
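Note: for reference, a minimal client sketch against the new OpenAI-compatible route this commit adds. The `/api/openai/v1/` prefix, the `POST /chat/completions` route, the required `messages` field, and the `max_tokens` -> `max_new_tokens` mapping all appear in the diff below; the host, port, and the `X-Api-Key` header are assumptions (the header name is used elsewhere in the codebase, but the exact check done by `require_api_key()` is not shown here).

# Hypothetical usage sketch, not part of the commit.
import requests

resp = requests.post(
    'http://localhost:5000/api/openai/v1/chat/completions',   # host/port are placeholders
    headers={'X-Api-Key': 'example-token'},                    # auth header is an assumption
    json={
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello!'},
        ],
        'max_tokens': 100,  # remapped to max_new_tokens by RequestHandler.load_parameters()
    },
)
print(resp.json()['choices'][0]['message']['content'])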

View File

@@ -16,7 +16,9 @@ config_default_vars = {
     'simultaneous_requests_per_ip': 3,
     'show_backend_info': True,
     'max_new_tokens': 500,
-    'manual_model_name': False
+    'manual_model_name': False,
+    'enable_openi_compatible_backend': True,
+    'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n""",
 }
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@@ -2,6 +2,8 @@ from typing import Tuple, Union

 class LLMBackend:
+    default_params: dict
+
     def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError
@@ -19,3 +21,6 @@ class LLMBackend:
         :return:
         """
         raise NotImplementedError
+
+    def validate_request(self, parameters: dict) -> (bool, Union[str, None]):
+        raise NotImplementedError

View File

@@ -8,7 +8,7 @@ from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json

-class OobaboogaLLMBackend(LLMBackend):
+class OobaboogaBackend(LLMBackend):
     def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
         backend_err = False
         response_valid_json, response_json_body = validate_json(response)

View File

@@ -36,7 +36,6 @@ def transform_to_text(json_request, api_response):
             data = json.loads(line[5:].strip())
         except json.decoder.JSONDecodeError:
             break
-        print(data)
         if 'choices' in data:
             for choice in data['choices']:
                 if 'delta' in choice and 'content' in choice['delta']:

View File

@@ -1,8 +1,9 @@
-from typing import Tuple
+from typing import Tuple, Union

 from flask import jsonify
 from vllm import SamplingParams

+from llm_server import opts
 from llm_server.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
 from llm_server.routes.helpers.client import format_sillytavern_err
@@ -14,7 +15,9 @@ from llm_server.routes.helpers.http import validate_json
 # TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69

 class VLLMBackend(LLMBackend):
-    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
+    default_params = vars(SamplingParams())
+
+    def handle_response(self, success, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers):
         response_valid_json, response_json_body = validate_json(response)
         backend_err = False
         try:
@@ -50,18 +53,18 @@
             }), 200

     # def validate_params(self, params_dict: dict):
-    #     default_params = SamplingParams()
+    #     self.default_params = SamplingParams()
     #     try:
     #         sampling_params = SamplingParams(
-    #             temperature=params_dict.get('temperature', default_params.temperature),
-    #             top_p=params_dict.get('top_p', default_params.top_p),
-    #             top_k=params_dict.get('top_k', default_params.top_k),
+    #             temperature=params_dict.get('temperature', self.default_paramstemperature),
+    #             top_p=params_dict.get('top_p', self.default_paramstop_p),
+    #             top_k=params_dict.get('top_k', self.default_paramstop_k),
     #             use_beam_search=True if params_dict['num_beams'] > 1 else False,
-    #             length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
-    #             early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
-    #             stop=params_dict.get('stopping_strings', default_params.stop),
+    #             length_penalty=params_dict.get('length_penalty', self.default_paramslength_penalty),
+    #             early_stopping=params_dict.get('early_stopping', self.default_paramsearly_stopping),
+    #             stop=params_dict.get('stopping_strings', self.default_paramsstop),
     #             ignore_eos=params_dict.get('ban_eos_token', False),
-    #             max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
+    #             max_tokens=params_dict.get('max_new_tokens', self.default_paramsmax_tokens)
     #         )
     #     except ValueError as e:
     #         print(e)
@@ -79,30 +82,21 @@ class VLLMBackend(LLMBackend):
     #         return False, e

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
-        default_params = SamplingParams()
         try:
             sampling_params = SamplingParams(
-                temperature=parameters.get('temperature', default_params.temperature),
-                top_p=parameters.get('top_p', default_params.top_p),
-                top_k=parameters.get('top_k', default_params.top_k),
-                use_beam_search=True if parameters['num_beams'] > 1 else False,
-                stop=parameters.get('stopping_strings', default_params.stop),
+                temperature=parameters.get('temperature', self.default_params['temperature']),
+                top_p=parameters.get('top_p', self.default_params['top_p']),
+                top_k=parameters.get('top_k', self.default_params['top_k']),
+                use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
+                stop=parameters.get('stopping_strings', self.default_params['stop']),
                 ignore_eos=parameters.get('ban_eos_token', False),
-                max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
+                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
             )
         except ValueError as e:
             return None, str(e).strip('.')
         return vars(sampling_params), None

-    # def transform_sampling_params(params: SamplingParams):
-    #     return {
-    #         'temperature': params['temperature'],
-    #         'top_p': params['top_p'],
-    #         'top_k': params['top_k'],
-    #         'use_beam_search' = True if parameters['num_beams'] > 1 else False,
-    #         length_penalty = parameters.get('length_penalty', default_params.length_penalty),
-    #         early_stopping = parameters.get('early_stopping', default_params.early_stopping),
-    #         stop = parameters.get('stopping_strings', default_params.stop),
-    #         ignore_eos = parameters.get('ban_eos_token', False),
-    #         max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
-    #     }
+    def validate_request(self, parameters) -> (bool, Union[str, None]):
+        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
+            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
+        return True, None
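Note: a rough illustration of what the new class-level defaults buy. `vars(SamplingParams())` flattens vLLM's default sampling settings into a plain dict, so `get_parameters()` can fall back to them with dict lookups and `validate_request()` can cap `max_new_tokens` against the configured limit. A sketch only, assuming vLLM is installed and `opts.max_new_tokens` has been set; the example values are illustrative, not part of the commit.

# Sketch of how VLLMBackend consumes its defaults; names follow the diff above.
from vllm import SamplingParams
from llm_server import opts

defaults = vars(SamplingParams())            # e.g. {'temperature': 1.0, 'top_p': 1.0, ...}
user_params = {'temperature': 0.7, 'max_new_tokens': 9999}

# get_parameters() falls back to the defaults dict per key:
temperature = user_params.get('temperature', defaults['temperature'])

# validate_request() rejects oversized generations:
if user_params.get('max_new_tokens', 0) > opts.max_new_tokens:
    print(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')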

View File

@@ -23,3 +23,6 @@ netdata_root = None
 simultaneous_requests_per_ip = 3
 show_backend_info = True
 manual_model_name = None
+llm_middleware_name = ''
+enable_openi_compatible_backend = True
+openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n"""

View File

@@ -1,11 +1,11 @@
 import json
+from functools import wraps
 from typing import Union

-from flask import make_response
-from requests import Response
+import flask
+import requests
+from flask import make_response, Request
 from flask import request, jsonify
-from functools import wraps

 from llm_server import opts
 from llm_server.database import is_valid_api_key
@@ -39,15 +39,18 @@ def require_api_key():
         return jsonify({'code': 401, 'message': 'API key required'}), 401

-def validate_json(data: Union[str, Response]):
-    if isinstance(data, Response):
-        try:
+def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response]):
+    try:
+        if isinstance(data, (Request, flask.Response)):
+            data = data.json
+            return True, data
+        elif isinstance(data, requests.models.Response):
             data = data.json()
             return True, data
-        except Exception as e:
-            return False, None
+    except Exception as e:
+        return False, e
     try:
-        j = json.loads(data)
+        j = json.loads(str(data))
         return True, j
     except Exception as e:
-        return False, None
+        return False, e

View File

@@ -0,0 +1,64 @@
import time

from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler


class OobaRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def handle_request(self):
        if self.used:
            raise Exception

        request_valid_json, self.request_json_body = validate_json(self.request)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[1]:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None
        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, response, error_msg = event.data
        end_time = time.time()
        elapsed_time = end_time - self.start_time
        self.used = True
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200

View File

@@ -0,0 +1,32 @@
from flask import Blueprint, request

from ..helpers.client import format_sillytavern_err
from ..helpers.http import require_api_key
from ..openai_request_handler import build_openai_response
from ..server_error import handle_server_error
from ... import opts

openai_bp = Blueprint('openai/v1/', __name__)


@openai_bp.before_request
def before_request():
    if not opts.http_host:
        opts.http_host = request.headers.get("Host")
    if not opts.enable_openi_compatible_backend:
        return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
    if not opts.base_client_api:
        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:
            return response


@openai_bp.errorhandler(500)
def handle_error(e):
    return handle_server_error(e)


from .models import openai_list_models
from .chat_completions import openai_chat_completions

View File

@@ -0,0 +1,24 @@
from flask import jsonify, request

from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler


class FakeFlaskRequest():
    def __init__(self, *args, **kwargs):
        self.data = kwargs.get('data')
        self.headers = kwargs.get('headers')
        self.json = kwargs.get('json')
        self.remote_addr = kwargs.get('remote_addr')


@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    # TODO: make this work with oobabooga
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
    else:
        handler = OpenAIRequestHandler(request)
        return handler.handle_request()

View File

@@ -0,0 +1,57 @@
from flask import jsonify, request

from . import openai_bp
from ..cache import cache, redis
from ..stats import server_start_time
from ... import opts
from ...llm.info import get_running_model


@openai_bp.route('/models', methods=['GET'])
def openai_list_models():
    cache_key = 'openai_model_cache::' + request.url
    cached_response = cache.get(cache_key)
    if cached_response:
        return cached_response

    model, error = get_running_model()
    if not model:
        response = jsonify({
            'code': 502,
            'msg': 'failed to reach backend',
            'type': error.__class__.__name__
        }), 500  # return 500 so Cloudflare doesn't intercept us
    else:
        response = jsonify({
            "object": "list",
            "data": [
                {
                    "id": opts.running_model,
                    "object": "model",
                    "created": int(server_start_time.timestamp()),
                    "owned_by": opts.llm_middleware_name,
                    "permission": [
                        {
                            "id": opts.running_model,
                            "object": "model_permission",
                            "created": int(server_start_time.timestamp()),
                            "allow_create_engine": False,
                            "allow_sampling": False,
                            "allow_logprobs": False,
                            "allow_search_indices": False,
                            "allow_view": True,
                            "allow_fine_tuning": False,
                            "organization": "*",
                            "group": None,
                            "is_blocking": False
                        }
                    ],
                    "root": None,
                    "parent": None
                }
            ]
        }), 200

    cache.set(cache_key, response, timeout=60)
    return response

View File

@@ -0,0 +1,132 @@
import re
import time
from uuid import uuid4

import tiktoken
from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.request_handler import RequestHandler

tokenizer = tiktoken.get_encoding("cl100k_base")


class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None

    def handle_request(self):
        if self.used:
            raise Exception

        request_valid_json, self.request_json_body = validate_json(self.request)
        self.prompt = self.transform_messages_to_prompt()
        if not request_valid_json or not self.prompt:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        llm_request = {**self.parameters, 'prompt': self.prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None
        if not event:
            return self.handle_ratelimited()

        event.wait()
        success, backend_response, error_msg = event.data
        end_time = time.time()
        elapsed_time = end_time - self.start_time
        self.used = True
        response, response_status_code = self.backend.handle_response(success, backend_response, error_msg, self.client_ip, self.token, self.prompt, elapsed_time, self.parameters, dict(self.request.headers))
        return build_openai_response(self.prompt, response.json['results'][0]['text']), 200

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return build_openai_response(self.prompt, backend_response), 200

    def transform_messages_to_prompt(self):
        try:
            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
            for msg in self.request.json['messages']:
                if not msg.get('content') or not msg.get('role'):
                    return False
                if msg['role'] == 'system':
                    prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
                elif msg['role'] == 'user':
                    prompt += f'### USER: {msg["content"]}\n\n'
                elif msg['role'] == 'assistant':
                    prompt += f'### ASSISTANT: {msg["content"]}\n\n'
                else:
                    return False
        except:
            return False
        prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
        prompt += '\n\n### RESPONSE: '
        return prompt


def build_openai_response(prompt, response):
    # Seperate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    y = response.split('\n### ')
    if len(x) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })

# def transform_prompt_to_text(prompt: list):
#     text = ''
#     for item in prompt:
#         text += item['content'] + '\n'
#     return text.strip('\n')
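Note: to make the instruction-style rewrite concrete, here is roughly what transform_messages_to_prompt() produces for a short conversation. A sketch only: the messages list is a hypothetical input, opts.openai_system_prompt is elided, and exact whitespace depends on the strip() calls above.

# Hypothetical request body passed to the handler:
messages = [
    {'role': 'user', 'content': 'What is a proxy?'},
    {'role': 'assistant', 'content': 'A server that forwards requests.'},
    {'role': 'user', 'content': 'Thanks!'},
]

# Approximate prompt sent to the LLM backend (system prompt elided):
expected = (
    '### INSTRUCTION: You are an assistant chatbot. ...\n\n'
    '### USER: What is a proxy?\n\n'
    '### ASSISTANT: A server that forwards requests.\n\n'
    '### USER: Thanks!'
    '\n\n### RESPONSE: '
)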

View File

@@ -2,36 +2,16 @@ import sqlite3
 import time
 from typing import Union

-from flask import jsonify
-
 from llm_server import opts
-from llm_server.database import log_prompt
-from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
-from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
+from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
-from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread

 DEFAULT_PRIORITY = 9999

-def delete_dict_key(d: dict, k: Union[str, list]):
-    if isinstance(k, str):
-        if k in d.keys():
-            del d[k]
-    elif isinstance(k, list):
-        for item in k:
-            if item in d.keys():
-                del d[item]
-    else:
-        raise ValueError
-    return d
-
-class OobaRequestHandler:
+class RequestHandler:
     def __init__(self, incoming_request):
         self.request_json_body = None
         self.request = incoming_request
@@ -39,14 +19,10 @@ class OobaRequestHandler:
         self.client_ip = self.get_client_ip()
         self.token = self.request.headers.get('X-Api-Key')
         self.priority = self.get_priority()
-        self.backend = self.get_backend()
+        self.backend = get_backend()
         self.parameters = self.parameters_invalid_msg = None
+        self.used = False
+        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

-    def validate_request(self) -> (bool, Union[str, None]):
-        # TODO: move this to LLMBackend
-        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
-            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
-        return True, None

     def get_client_ip(self):
         if self.request.headers.get('cf-connecting-ip'):
@@ -67,69 +43,53 @@ class OobaRequestHandler:
             return result[0]
         return DEFAULT_PRIORITY

-    def get_backend(self):
-        if opts.mode == 'oobabooga':
-            return OobaboogaLLMBackend()
-        elif opts.mode == 'vllm':
-            return VLLMBackend()
-        else:
-            raise Exception
-
-    def get_parameters(self):
+    def load_parameters(self):
+        # Handle OpenAI
+        if self.request_json_body.get('max_tokens'):
+            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
         self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

-    def handle_request(self):
-        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
-
-        request_valid_json, self.request_json_body = validate_json(self.request.data)
-        if not request_valid_json:
-            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
-
-        self.get_parameters()
+    def validate_request(self):
+        self.load_parameters()
         params_valid = False
         request_valid = False
         invalid_request_err_msg = None
         if self.parameters:
             params_valid = True
-            request_valid, invalid_request_err_msg = self.validate_request()
-
-        if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
-            combined_error_message = ', '.join(error_messages)
-            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
-            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
-            return jsonify({
-                'code': 400,
-                'msg': 'parameter validation error',
-                'results': [{'text': err}]
-            }), 200
-
-        # Reconstruct the request JSON with the validated parameters and prompt.
-        prompt = self.request_json_body.get('prompt', '')
-        llm_request = {**self.parameters, 'prompt': prompt}
+            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
+        return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)

+    def is_client_ratelimited(self):
         queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
         if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
-            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
+            return False
         else:
-            # Client was rate limited
-            event = None
-
-        if not event:
-            return self.handle_ratelimited()
-
-        event.wait()
-        success, response, error_msg = event.data
-        end_time = time.time()
-        elapsed_time = end_time - self.start_time
-        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
+            return True
+
+    def handle_request(self):
+        raise NotImplementedError

     def handle_ratelimited(self):
-        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
-        return jsonify({
-            'results': [{'text': backend_response}]
-        }), 200
+        raise NotImplementedError
+
+
+def get_backend():
+    if opts.mode == 'oobabooga':
+        return OobaboogaBackend()
+    elif opts.mode == 'vllm':
+        return VLLMBackend()
+    else:
+        raise Exception
+
+
+def delete_dict_key(d: dict, k: Union[str, list]):
+    if isinstance(k, str):
+        if k in d.keys():
+            del d[k]
+    elif isinstance(k, list):
+        for item in k:
+            if item in d.keys():
+                del d[item]
+    else:
+        raise ValueError
+    return d

View File

@@ -6,6 +6,7 @@ from llm_server.routes.cache import redis
 # proompters_1_min = 0
 # concurrent_semaphore = Semaphore(concurrent_gens)
 server_start_time = datetime.now()
 # TODO: have a background thread put the averages in a variable so we don't end up with massive arrays

View File

@@ -2,13 +2,13 @@ from flask import jsonify, request

 from . import bp
 from ..helpers.http import validate_json
-from ..request_handler import OobaRequestHandler
+from ..ooba_request_handler import OobaRequestHandler
 from ... import opts


 @bp.route('/generate', methods=['POST'])
 def generate():
-    request_valid_json, request_json_body = validate_json(request.data)
+    request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:

View File

@@ -62,10 +62,10 @@ def generate_stats():
             'power_state': power_state,
             # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
         }
     else:
         netdata_stats = {}

     output = {
         'stats': {
             'proompters': {

View File

@@ -8,7 +8,6 @@ from ..helpers.client import format_sillytavern_err
 from ... import opts
 from ...database import log_prompt
 from ...helpers import indefinite_article
-from ...llm.hf_textgen.generate import prepare_json
 from ...stream import sock

View File

@@ -1,4 +1,5 @@
 import time
+from datetime import datetime
 from threading import Thread

 import requests

View File

@@ -6,6 +6,7 @@ from threading import Thread
 from flask import Flask, jsonify, render_template, request

+from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error

 try:
@@ -68,6 +69,9 @@ opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
 opts.show_backend_info = config['show_backend_info']
 opts.max_new_tokens = config['max_new_tokens']
 opts.manual_model_name = config['manual_model_name']
+opts.llm_middleware_name = config['llm_middleware_name']
+opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
+opts.openai_system_prompt = config['openai_system_prompt']
 opts.verify_ssl = config['verify_ssl']
 if not opts.verify_ssl:
@@ -105,6 +109,7 @@ init_socketio(app)
 # with app.app_context():
 #     current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
 app.register_blueprint(bp, url_prefix='/api/v1/')
+app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 # print(app.url_map)
@@ -145,7 +150,7 @@ def home():
         mode_info = vllm_info

     return render_template('home.html',
-                           llm_middleware_name=config['llm_middleware_name'],
+                           llm_middleware_name=opts.llm_middleware_name,
                            analytics_tracking_code=analytics_tracking_code,
                            info_html=info_html,
                            current_model=opts.manual_model_name if opts.manual_model_name else running_model,
@@ -158,6 +163,7 @@ def home():
                            context_size=opts.context_size,
                            stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                            extra_info=mode_info,
+                           openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                            )

View File

@@ -76,6 +76,7 @@
                 <p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
                 <p><strong>Client API URL:</strong> {{ client_api }}</p>
                 <p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
+                <p><strong>OpenAI-Compatible API URL:</strong> {{ openai_client_api }}</p>
                 <p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
                 {{ info_html|safe }}
             </div>