add openai-compatible backend
This commit is contained in:
parent
1d9f40765e
commit
9740df07c7
|
@ -16,7 +16,9 @@ config_default_vars = {
|
|||
'simultaneous_requests_per_ip': 3,
|
||||
'show_backend_info': True,
|
||||
'max_new_tokens': 500,
|
||||
'manual_model_name': False
|
||||
'manual_model_name': False,
|
||||
'enable_openi_compatible_backend': True,
|
||||
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n""",
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ from typing import Tuple, Union
|
|||
|
||||
|
||||
class LLMBackend:
|
||||
default_params: dict
|
||||
|
||||
def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -19,3 +21,6 @@ class LLMBackend:
|
|||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_request(self, parameters: dict) -> (bool, Union[str, None]):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -8,7 +8,7 @@ from ...routes.helpers.client import format_sillytavern_err
|
|||
from ...routes.helpers.http import validate_json
|
||||
|
||||
|
||||
class OobaboogaLLMBackend(LLMBackend):
|
||||
class OobaboogaBackend(LLMBackend):
|
||||
def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
backend_err = False
|
||||
response_valid_json, response_json_body = validate_json(response)
|
||||
|
|
|
@ -36,7 +36,6 @@ def transform_to_text(json_request, api_response):
|
|||
data = json.loads(line[5:].strip())
|
||||
except json.decoder.JSONDecodeError:
|
||||
break
|
||||
print(data)
|
||||
if 'choices' in data:
|
||||
for choice in data['choices']:
|
||||
if 'delta' in choice and 'content' in choice['delta']:
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Tuple
|
||||
from typing import Tuple, Union
|
||||
|
||||
from flask import jsonify
|
||||
from vllm import SamplingParams
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
from llm_server.llm.llm_backend import LLMBackend
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
|
@ -14,7 +15,9 @@ from llm_server.routes.helpers.http import validate_json
|
|||
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
|
||||
|
||||
class VLLMBackend(LLMBackend):
|
||||
def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
default_params = vars(SamplingParams())
|
||||
|
||||
def handle_response(self, success, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers):
|
||||
response_valid_json, response_json_body = validate_json(response)
|
||||
backend_err = False
|
||||
try:
|
||||
|
@ -50,18 +53,18 @@ class VLLMBackend(LLMBackend):
|
|||
}), 200
|
||||
|
||||
# def validate_params(self, params_dict: dict):
|
||||
# default_params = SamplingParams()
|
||||
# self.default_params = SamplingParams()
|
||||
# try:
|
||||
# sampling_params = SamplingParams(
|
||||
# temperature=params_dict.get('temperature', default_params.temperature),
|
||||
# top_p=params_dict.get('top_p', default_params.top_p),
|
||||
# top_k=params_dict.get('top_k', default_params.top_k),
|
||||
# temperature=params_dict.get('temperature', self.default_paramstemperature),
|
||||
# top_p=params_dict.get('top_p', self.default_paramstop_p),
|
||||
# top_k=params_dict.get('top_k', self.default_paramstop_k),
|
||||
# use_beam_search=True if params_dict['num_beams'] > 1 else False,
|
||||
# length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
|
||||
# early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
|
||||
# stop=params_dict.get('stopping_strings', default_params.stop),
|
||||
# length_penalty=params_dict.get('length_penalty', self.default_paramslength_penalty),
|
||||
# early_stopping=params_dict.get('early_stopping', self.default_paramsearly_stopping),
|
||||
# stop=params_dict.get('stopping_strings', self.default_paramsstop),
|
||||
# ignore_eos=params_dict.get('ban_eos_token', False),
|
||||
# max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
|
||||
# max_tokens=params_dict.get('max_new_tokens', self.default_paramsmax_tokens)
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# print(e)
|
||||
|
@ -79,30 +82,21 @@ class VLLMBackend(LLMBackend):
|
|||
# return False, e
|
||||
|
||||
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
|
||||
default_params = SamplingParams()
|
||||
try:
|
||||
sampling_params = SamplingParams(
|
||||
temperature=parameters.get('temperature', default_params.temperature),
|
||||
top_p=parameters.get('top_p', default_params.top_p),
|
||||
top_k=parameters.get('top_k', default_params.top_k),
|
||||
use_beam_search=True if parameters['num_beams'] > 1 else False,
|
||||
stop=parameters.get('stopping_strings', default_params.stop),
|
||||
temperature=parameters.get('temperature', self.default_params['temperature']),
|
||||
top_p=parameters.get('top_p', self.default_params['top_p']),
|
||||
top_k=parameters.get('top_k', self.default_params['top_k']),
|
||||
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
|
||||
stop=parameters.get('stopping_strings', self.default_params['stop']),
|
||||
ignore_eos=parameters.get('ban_eos_token', False),
|
||||
max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
|
||||
max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
|
||||
)
|
||||
except ValueError as e:
|
||||
return None, str(e).strip('.')
|
||||
return vars(sampling_params), None
|
||||
|
||||
# def transform_sampling_params(params: SamplingParams):
|
||||
# return {
|
||||
# 'temperature': params['temperature'],
|
||||
# 'top_p': params['top_p'],
|
||||
# 'top_k': params['top_k'],
|
||||
# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
|
||||
# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
|
||||
# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
|
||||
# stop = parameters.get('stopping_strings', default_params.stop),
|
||||
# ignore_eos = parameters.get('ban_eos_token', False),
|
||||
# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
|
||||
# }
|
||||
def validate_request(self, parameters) -> (bool, Union[str, None]):
|
||||
if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
|
||||
return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
|
||||
return True, None
|
||||
|
|
|
@ -23,3 +23,6 @@ netdata_root = None
|
|||
simultaneous_requests_per_ip = 3
|
||||
show_backend_info = True
|
||||
manual_model_name = None
|
||||
llm_middleware_name = ''
|
||||
enable_openi_compatible_backend = True
|
||||
openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nYou are the assistant and answer to the `### RESPONSE` prompt. Lines that start with `### ASSISTANT` were messages you sent previously.\nLines that start with `### USER` were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompts and follow the instructions given by the user.\n\n"""
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import json
|
||||
from functools import wraps
|
||||
from typing import Union
|
||||
|
||||
from flask import make_response
|
||||
from requests import Response
|
||||
|
||||
import flask
|
||||
import requests
|
||||
from flask import make_response, Request
|
||||
from flask import request, jsonify
|
||||
from functools import wraps
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import is_valid_api_key
|
||||
|
@ -39,15 +39,18 @@ def require_api_key():
|
|||
return jsonify({'code': 401, 'message': 'API key required'}), 401
|
||||
|
||||
|
||||
def validate_json(data: Union[str, Response]):
|
||||
if isinstance(data, Response):
|
||||
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response]):
|
||||
try:
|
||||
if isinstance(data, (Request, flask.Response)):
|
||||
data = data.json
|
||||
return True, data
|
||||
elif isinstance(data, requests.models.Response):
|
||||
data = data.json()
|
||||
return True, data
|
||||
except Exception as e:
|
||||
return False, None
|
||||
return False, e
|
||||
try:
|
||||
j = json.loads(data)
|
||||
j = json.loads(str(data))
|
||||
return True, j
|
||||
except Exception as e:
|
||||
return False, None
|
||||
return False, e
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
import time
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.helpers.http import validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
|
||||
class OobaRequestHandler(RequestHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def handle_request(self):
|
||||
if self.used:
|
||||
raise Exception
|
||||
|
||||
request_valid_json, self.request_json_body = validate_json(self.request)
|
||||
if not request_valid_json:
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
|
||||
params_valid, request_valid = self.validate_request()
|
||||
if not request_valid[0] or not params_valid[1]:
|
||||
error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
|
||||
combined_error_message = ', '.join(error_messages)
|
||||
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
|
||||
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
|
||||
return jsonify({
|
||||
'code': 400,
|
||||
'msg': 'parameter validation error',
|
||||
'results': [{'text': err}]
|
||||
}), 200
|
||||
|
||||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
prompt = self.request_json_body.get('prompt', '')
|
||||
llm_request = {**self.parameters, 'prompt': prompt}
|
||||
|
||||
if not self.is_client_ratelimited():
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
|
||||
else:
|
||||
event = None
|
||||
|
||||
if not event:
|
||||
return self.handle_ratelimited()
|
||||
|
||||
event.wait()
|
||||
success, response, error_msg = event.data
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
self.used = True
|
||||
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||
|
||||
def handle_ratelimited(self):
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
|
||||
return jsonify({
|
||||
'results': [{'text': backend_response}]
|
||||
}), 200
|
|
@ -0,0 +1,32 @@
|
|||
from flask import Blueprint, request
|
||||
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ..helpers.http import require_api_key
|
||||
from ..openai_request_handler import build_openai_response
|
||||
from ..server_error import handle_server_error
|
||||
from ... import opts
|
||||
|
||||
openai_bp = Blueprint('openai/v1/', __name__)
|
||||
|
||||
|
||||
@openai_bp.before_request
|
||||
def before_request():
|
||||
if not opts.http_host:
|
||||
opts.http_host = request.headers.get("Host")
|
||||
if not opts.enable_openi_compatible_backend:
|
||||
return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
|
||||
if not opts.base_client_api:
|
||||
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||
if request.endpoint != 'v1.get_stats':
|
||||
response = require_api_key()
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
|
||||
@openai_bp.errorhandler(500)
|
||||
def handle_error(e):
|
||||
return handle_server_error(e)
|
||||
|
||||
|
||||
from .models import openai_list_models
|
||||
from .chat_completions import openai_chat_completions
|
|
@ -0,0 +1,24 @@
|
|||
from flask import jsonify, request
|
||||
|
||||
from . import openai_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
|
||||
|
||||
class FakeFlaskRequest():
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.data = kwargs.get('data')
|
||||
self.headers = kwargs.get('headers')
|
||||
self.json = kwargs.get('json')
|
||||
self.remote_addr = kwargs.get('remote_addr')
|
||||
|
||||
|
||||
@openai_bp.route('/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions():
|
||||
# TODO: make this work with oobabooga
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('messages'):
|
||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
else:
|
||||
handler = OpenAIRequestHandler(request)
|
||||
return handler.handle_request()
|
|
@ -0,0 +1,57 @@
|
|||
from flask import jsonify, request
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import cache, redis
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...llm.info import get_running_model
|
||||
|
||||
|
||||
@openai_bp.route('/models', methods=['GET'])
|
||||
def openai_list_models():
|
||||
cache_key = 'openai_model_cache::' + request.url
|
||||
cached_response = cache.get(cache_key)
|
||||
|
||||
if cached_response:
|
||||
return cached_response
|
||||
|
||||
model, error = get_running_model()
|
||||
if not model:
|
||||
response = jsonify({
|
||||
'code': 502,
|
||||
'msg': 'failed to reach backend',
|
||||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
else:
|
||||
response = jsonify({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": opts.running_model,
|
||||
"object": "model",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"owned_by": opts.llm_middleware_name,
|
||||
"permission": [
|
||||
{
|
||||
"id": opts.running_model,
|
||||
"object": "model_permission",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"allow_create_engine": False,
|
||||
"allow_sampling": False,
|
||||
"allow_logprobs": False,
|
||||
"allow_search_indices": False,
|
||||
"allow_view": True,
|
||||
"allow_fine_tuning": False,
|
||||
"organization": "*",
|
||||
"group": None,
|
||||
"is_blocking": False
|
||||
}
|
||||
],
|
||||
"root": None,
|
||||
"parent": None
|
||||
}
|
||||
]
|
||||
}), 200
|
||||
cache.set(cache_key, response, timeout=60)
|
||||
|
||||
return response
|
|
@ -0,0 +1,132 @@
|
|||
import re
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import tiktoken
|
||||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.helpers.http import validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
class OpenAIRequestHandler(RequestHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prompt = None
|
||||
|
||||
def handle_request(self):
|
||||
if self.used:
|
||||
raise Exception
|
||||
|
||||
request_valid_json, self.request_json_body = validate_json(self.request)
|
||||
self.prompt = self.transform_messages_to_prompt()
|
||||
|
||||
if not request_valid_json or not self.prompt:
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
|
||||
params_valid, request_valid = self.validate_request()
|
||||
if not request_valid[0] or not params_valid[0]:
|
||||
error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
|
||||
combined_error_message = ', '.join(error_messages)
|
||||
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
|
||||
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
|
||||
return jsonify({
|
||||
'code': 400,
|
||||
'msg': 'parameter validation error',
|
||||
'results': [{'text': err}]
|
||||
}), 200
|
||||
|
||||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
|
||||
if not self.is_client_ratelimited():
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
|
||||
else:
|
||||
event = None
|
||||
|
||||
if not event:
|
||||
return self.handle_ratelimited()
|
||||
|
||||
event.wait()
|
||||
success, backend_response, error_msg = event.data
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
self.used = True
|
||||
response, response_status_code = self.backend.handle_response(success, backend_response, error_msg, self.client_ip, self.token, self.prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||
return build_openai_response(self.prompt, response.json['results'][0]['text']), 200
|
||||
|
||||
def handle_ratelimited(self):
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
|
||||
return build_openai_response(self.prompt, backend_response), 200
|
||||
|
||||
def transform_messages_to_prompt(self):
|
||||
try:
|
||||
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
|
||||
for msg in self.request.json['messages']:
|
||||
if not msg.get('content') or not msg.get('role'):
|
||||
return False
|
||||
if msg['role'] == 'system':
|
||||
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
|
||||
elif msg['role'] == 'user':
|
||||
prompt += f'### USER: {msg["content"]}\n\n'
|
||||
elif msg['role'] == 'assistant':
|
||||
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
|
||||
prompt = prompt.strip(' ').strip('\n').strip('\n\n') # TODO: this is really lazy
|
||||
prompt += '\n\n### RESPONSE: '
|
||||
return prompt
|
||||
|
||||
|
||||
def build_openai_response(prompt, response):
|
||||
# Seperate the user's prompt from the context
|
||||
x = prompt.split('### USER:')
|
||||
if len(x) > 1:
|
||||
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
|
||||
|
||||
# Make sure the bot doesn't put any other instructions in its response
|
||||
y = response.split('\n### ')
|
||||
if len(x) > 1:
|
||||
response = re.sub(r'\n$', '', y[0].strip(' '))
|
||||
|
||||
prompt_tokens = len(tokenizer.encode(prompt))
|
||||
response_tokens = len(tokenizer.encode(response))
|
||||
return jsonify({
|
||||
"id": f"chatcmpl-{uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": opts.running_model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
})
|
||||
|
||||
# def transform_prompt_to_text(prompt: list):
|
||||
# text = ''
|
||||
# for item in prompt:
|
||||
# text += item['content'] + '\n'
|
||||
# return text.strip('\n')
|
|
@ -2,36 +2,16 @@ import sqlite3
|
|||
import time
|
||||
from typing import Union
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import log_prompt
|
||||
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.helpers.http import validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import SemaphoreCheckerThread
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
def delete_dict_key(d: dict, k: Union[str, list]):
|
||||
if isinstance(k, str):
|
||||
if k in d.keys():
|
||||
del d[k]
|
||||
elif isinstance(k, list):
|
||||
for item in k:
|
||||
if item in d.keys():
|
||||
del d[item]
|
||||
else:
|
||||
raise ValueError
|
||||
return d
|
||||
|
||||
|
||||
class OobaRequestHandler:
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request):
|
||||
self.request_json_body = None
|
||||
self.request = incoming_request
|
||||
|
@ -39,14 +19,10 @@ class OobaRequestHandler:
|
|||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.request.headers.get('X-Api-Key')
|
||||
self.priority = self.get_priority()
|
||||
self.backend = self.get_backend()
|
||||
self.backend = get_backend()
|
||||
self.parameters = self.parameters_invalid_msg = None
|
||||
|
||||
def validate_request(self) -> (bool, Union[str, None]):
|
||||
# TODO: move this to LLMBackend
|
||||
if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
|
||||
return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
|
||||
return True, None
|
||||
self.used = False
|
||||
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
|
||||
|
||||
def get_client_ip(self):
|
||||
if self.request.headers.get('cf-connecting-ip'):
|
||||
|
@ -67,69 +43,53 @@ class OobaRequestHandler:
|
|||
return result[0]
|
||||
return DEFAULT_PRIORITY
|
||||
|
||||
def get_backend(self):
|
||||
if opts.mode == 'oobabooga':
|
||||
return OobaboogaLLMBackend()
|
||||
elif opts.mode == 'vllm':
|
||||
return VLLMBackend()
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
def get_parameters(self):
|
||||
def load_parameters(self):
|
||||
# Handle OpenAI
|
||||
if self.request_json_body.get('max_tokens'):
|
||||
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
|
||||
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
|
||||
|
||||
def handle_request(self):
|
||||
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
|
||||
|
||||
request_valid_json, self.request_json_body = validate_json(self.request.data)
|
||||
if not request_valid_json:
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
|
||||
self.get_parameters()
|
||||
def validate_request(self):
|
||||
self.load_parameters()
|
||||
params_valid = False
|
||||
request_valid = False
|
||||
invalid_request_err_msg = None
|
||||
if self.parameters:
|
||||
params_valid = True
|
||||
request_valid, invalid_request_err_msg = self.validate_request()
|
||||
|
||||
if not request_valid or not params_valid:
|
||||
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
|
||||
combined_error_message = ', '.join(error_messages)
|
||||
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
|
||||
# TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
|
||||
return jsonify({
|
||||
'code': 400,
|
||||
'msg': 'parameter validation error',
|
||||
'results': [{'text': err}]
|
||||
}), 200
|
||||
|
||||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
prompt = self.request_json_body.get('prompt', '')
|
||||
llm_request = {**self.parameters, 'prompt': prompt}
|
||||
request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
|
||||
return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)
|
||||
|
||||
def is_client_ratelimited(self):
|
||||
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
|
||||
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
|
||||
return False
|
||||
else:
|
||||
# Client was rate limited
|
||||
event = None
|
||||
return True
|
||||
|
||||
if not event:
|
||||
return self.handle_ratelimited()
|
||||
|
||||
event.wait()
|
||||
success, response, error_msg = event.data
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||
def handle_request(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_ratelimited(self):
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
|
||||
return jsonify({
|
||||
'results': [{'text': backend_response}]
|
||||
}), 200
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_backend():
|
||||
if opts.mode == 'oobabooga':
|
||||
return OobaboogaBackend()
|
||||
elif opts.mode == 'vllm':
|
||||
return VLLMBackend()
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
||||
def delete_dict_key(d: dict, k: Union[str, list]):
|
||||
if isinstance(k, str):
|
||||
if k in d.keys():
|
||||
del d[k]
|
||||
elif isinstance(k, list):
|
||||
for item in k:
|
||||
if item in d.keys():
|
||||
del d[item]
|
||||
else:
|
||||
raise ValueError
|
||||
return d
|
||||
|
|
|
@ -6,6 +6,7 @@ from llm_server.routes.cache import redis
|
|||
|
||||
# proompters_1_min = 0
|
||||
# concurrent_semaphore = Semaphore(concurrent_gens)
|
||||
|
||||
server_start_time = datetime.now()
|
||||
|
||||
# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
|
||||
|
|
|
@ -2,13 +2,13 @@ from flask import jsonify, request
|
|||
|
||||
from . import bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..request_handler import OobaRequestHandler
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ... import opts
|
||||
|
||||
|
||||
@bp.route('/generate', methods=['POST'])
|
||||
def generate():
|
||||
request_valid_json, request_json_body = validate_json(request.data)
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
|
|
|
@ -62,10 +62,10 @@ def generate_stats():
|
|||
'power_state': power_state,
|
||||
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
|
||||
}
|
||||
|
||||
else:
|
||||
netdata_stats = {}
|
||||
|
||||
|
||||
output = {
|
||||
'stats': {
|
||||
'proompters': {
|
||||
|
|
|
@ -8,7 +8,6 @@ from ..helpers.client import format_sillytavern_err
|
|||
from ... import opts
|
||||
from ...database import log_prompt
|
||||
from ...helpers import indefinite_article
|
||||
from ...llm.hf_textgen.generate import prepare_json
|
||||
from ...stream import sock
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
|
||||
import requests
|
||||
|
|
|
@ -6,6 +6,7 @@ from threading import Thread
|
|||
|
||||
from flask import Flask, jsonify, render_template, request
|
||||
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
|
||||
try:
|
||||
|
@ -68,6 +69,9 @@ opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
|
|||
opts.show_backend_info = config['show_backend_info']
|
||||
opts.max_new_tokens = config['max_new_tokens']
|
||||
opts.manual_model_name = config['manual_model_name']
|
||||
opts.llm_middleware_name = config['llm_middleware_name']
|
||||
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
|
||||
opts.openai_system_prompt = config['openai_system_prompt']
|
||||
|
||||
opts.verify_ssl = config['verify_ssl']
|
||||
if not opts.verify_ssl:
|
||||
|
@ -105,6 +109,7 @@ init_socketio(app)
|
|||
# with app.app_context():
|
||||
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
|
||||
|
||||
# print(app.url_map)
|
||||
|
@ -145,7 +150,7 @@ def home():
|
|||
mode_info = vllm_info
|
||||
|
||||
return render_template('home.html',
|
||||
llm_middleware_name=config['llm_middleware_name'],
|
||||
llm_middleware_name=opts.llm_middleware_name,
|
||||
analytics_tracking_code=analytics_tracking_code,
|
||||
info_html=info_html,
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
|
||||
|
@ -158,6 +163,7 @@ def home():
|
|||
context_size=opts.context_size,
|
||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||
extra_info=mode_info,
|
||||
openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@
|
|||
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
||||
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
||||
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
|
||||
<p><strong>OpenAI-Compatible API URL:</strong> {{ openai_client_api }}</p>
|
||||
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
||||
{{ info_html|safe }}
|
||||
</div>
|
||||
|
|
Reference in New Issue