update openai endpoints
commit 2a3ff7e21e
parent 93d19fb95b
@@ -43,6 +43,8 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 ### Use
 
+If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
+
 Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
 
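The two notes above describe the assumed startup order: `daemon.py` must be running (it populates Redis) before the web server is started under Gunicorn. A minimal launcher sketch of that order, assuming `daemon.py` is invoked directly with `python`; only the Gunicorn command comes from the README:

# Hypothetical launcher sketch: start daemon.py first so the data it writes to
# Redis exists before server.py starts handling requests under Gunicorn.
import subprocess
import time

daemon = subprocess.Popen(['python', 'daemon.py'])  # assumption: direct invocation
time.sleep(5)  # give the daemon a moment to populate Redis
web = subprocess.Popen(['gunicorn', '-b', ':5000', '--worker-class', 'gevent', 'server:app'])
try:
    web.wait()
finally:
    daemon.terminate()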
@@ -60,7 +60,8 @@ def get_model_choices(regen: bool = False):
             if len(context_size):
                 model_choices[model]['context_size'] = min(context_size)
 
-    model_choices = dict(sorted(model_choices.items()))
+    # Python wants to sort lowercase vs. uppercase letters differently.
+    model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
 
     default_backend = get_a_cluster_backend()
     default_backend_dict = {}
@@ -2,7 +2,6 @@ from typing import Tuple, Union
 
 import flask
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
 from llm_server.llm import get_token_count
@@ -36,6 +35,8 @@ class LLMBackend:
         """
         If a backend needs to do other checks not related to the prompt or parameters.
         Default is no extra checks preformed.
+        :param request:
+        :param prompt:
         :param parameters:
         :return:
         """
@@ -0,0 +1,63 @@
+from flask import jsonify
+
+from llm_server import opts
+
+
+def oai_to_vllm(request_json_body, hashes: bool, mode):
+    if not request_json_body.get('stop'):
+        request_json_body['stop'] = []
+
+    if hashes:
+        request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
+        if opts.openai_force_no_hashes:
+            request_json_body['stop'].append('### ')
+    else:
+        request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
+
+    if request_json_body.get('frequency_penalty', 0) < -2:
+        request_json_body['frequency_penalty'] = -2
+    elif request_json_body.get('frequency_penalty', 0) > 2:
+        request_json_body['frequency_penalty'] = 2
+
+    if mode == 'vllm' and request_json_body.get('top_p') == 0:
+        request_json_body['top_p'] = 0.01
+
+    return request_json_body
+
+
+def format_oai_err(err_msg):
+    return jsonify({
+        "error": {
+            "message": err_msg,
+            "type": "invalid_request_error",
+            "param": None,
+            "code": None
+        }
+    }), 400
+
+
+def validate_oai(parameters):
+    if parameters['temperature'] > 2:
+        return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
+    if parameters['temperature'] < 0:
+        return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")
+
+    if parameters.get('top_p', 1) > 2:
+        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
+    if parameters.get('top_p', 1) < 0:
+        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
+
+    if parameters.get('presence_penalty', 1) > 2:
+        return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
+    if parameters.get('presence_penalty', 1) < -2:
+        return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")
+
+    if parameters.get('top_p', 1) > 2:
+        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
+    if parameters.get('top_p', 1) < 0:
+        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
+
+    if parameters.get('top_p', 1) > 2:
+        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
+    if parameters.get('top_p', 1) < 0:
+        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
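As a rough illustration of what `oai_to_vllm` above does to out-of-range OpenAI parameters, here is a minimal standalone sketch of the same clamping idea; the function name `clamp_oai_params` is hypothetical and the Flask/`opts`-dependent stop-string handling is omitted:

# Simplified restatement of the clamping performed by oai_to_vllm, for illustration only.
def clamp_oai_params(body: dict, mode: str = 'vllm') -> dict:
    body.setdefault('stop', [])
    # keep frequency_penalty inside the OpenAI-documented [-2, 2] range
    body['frequency_penalty'] = max(-2, min(2, body.get('frequency_penalty', 0)))
    # vLLM rejects top_p == 0, so nudge it to a small positive value
    if mode == 'vllm' and body.get('top_p') == 0:
        body['top_p'] = 0.01
    return body


print(clamp_oai_params({'top_p': 0, 'frequency_penalty': -5}))
# {'top_p': 0.01, 'frequency_penalty': -2, 'stop': []}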
@@ -2,73 +2,24 @@ import concurrent.futures
 import re
 import secrets
 import string
-import time
 import traceback
 from typing import Dict, List
 
 import tiktoken
-from flask import jsonify, make_response
-
-import llm_server
 from llm_server import opts
 from llm_server.llm import get_token_count
-from llm_server.custom_redis import redis
 
 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
 ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
 
 
-def build_openai_response(prompt, response, model=None):
-    # Seperate the user's prompt from the context
-    x = prompt.split('### USER:')
-    if len(x) > 1:
-        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
-
-    # Make sure the bot doesn't put any other instructions in its response
-    # y = response.split('\n### ')
-    # if len(y) > 1:
-    #     response = re.sub(r'\n$', '', y[0].strip(' '))
-    response = re.sub(ANTI_RESPONSE_RE, '', response)
-    response = re.sub(ANTI_CONTINUATION_RE, '', response)
-
-    # TODO: async/await
-    prompt_tokens = llm_server.llm.get_token_count(prompt)
-    response_tokens = llm_server.llm.get_token_count(response)
-    running_model = redis.get('running_model', 'ERROR', dtype=str)
-
-    response = make_response(jsonify({
-        "id": f"chatcmpl-{generate_oai_string(30)}",
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": running_model if opts.openai_expose_our_model else model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response,
-            },
-            "logprobs": None,
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": response_tokens,
-            "total_tokens": prompt_tokens + response_tokens
-        }
-    }), 200)
-
-    stats = redis.get('proxy_stats', dtype=dict)
-    if stats:
-        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
-    return response
-
-
 def generate_oai_string(length=24):
     alphabet = string.ascii_letters + string.digits
     return ''.join(secrets.choice(alphabet) for i in range(length))
 
 
-def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
+def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
     tokenizer = tiktoken.get_encoding("cl100k_base")
 
     def get_token_count_tiktoken_thread(msg):
@@ -95,13 +46,13 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
                 break
 
     def get_token_count_thread(msg):
-        return get_token_count(msg["content"])
+        return get_token_count(msg["content"], backend_url)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
         token_counts = list(executor.map(get_token_count_thread, prompt))
 
     total_tokens = sum(token_counts)
-    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
+    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
     if total_tokens + formatting_tokens > context_token_limit:
         # Start over, but this time calculate the token count using the backend
@@ -109,6 +60,40 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
             token_counts = list(executor.map(get_token_count_thread, prompt))
         else:
             break
+
+    return prompt
+
+
+def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+
+    def get_token_count_tiktoken_thread(msg):
+        return len(tokenizer.encode(msg))
+
+    token_count = get_token_count_tiktoken_thread(prompt)
+
+    # If total tokens exceed the limit, start trimming
+    if token_count > context_token_limit:
+        while True:
+            while token_count > context_token_limit:
+                # Calculate the index to start removing characters from
+                remove_index = len(prompt) // 3
+
+                while remove_index < len(prompt):
+                    prompt = prompt[:remove_index] + prompt[remove_index + 100:]
+                    token_count = get_token_count_tiktoken_thread(prompt)
+                    if token_count <= context_token_limit or remove_index == len(prompt):
+                        break
+
+            def get_token_count_thread(msg):
+                return get_token_count(msg, backend_url)
+
+            token_count = get_token_count_thread(prompt)
+
+            if token_count > context_token_limit:
+                # Start over, but this time calculate the token count using the backend
+                token_count = get_token_count_thread(prompt)
+            else:
+                break
+
     return prompt
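The new `trim_messages_to_fit` and `trim_string_to_fit` helpers first estimate length with tiktoken and only re-check with the backend tokenizer once the rough trim is done. A standalone sketch of the string-trimming idea, with the backend re-check omitted and the helper name `rough_trim` being illustrative:

# Chop 100-character slices out of the prompt (starting a third of the way in,
# as trim_string_to_fit does) until the tiktoken count fits the context limit.
import tiktoken

ENC = tiktoken.get_encoding("cl100k_base")


def rough_trim(prompt: str, context_token_limit: int) -> str:
    while prompt and len(ENC.encode(prompt)) > context_token_limit:
        cut = len(prompt) // 3
        prompt = prompt[:cut] + prompt[cut + 100:]
    return prompt


print(len(ENC.encode(rough_trim("word " * 2000, 512))) <= 512)  # True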
@@ -1,4 +1,3 @@
-import threading
 from typing import Tuple
 
 from flask import jsonify
@@ -35,9 +34,11 @@ class VLLMBackend(LLMBackend):
                 top_p=parameters.get('top_p', self._default_params['top_p']),
                 top_k=top_k,
                 use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-                stop=parameters.get('stopping_strings', self._default_params['stop']),
+                stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
                 ignore_eos=parameters.get('ban_eos_token', False),
-                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
+                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
+                presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
+                frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
             )
         except ValueError as e:
             return None, str(e).strip('.')
@@ -1,5 +1,4 @@
 import json
-import threading
 import time
 import traceback
 
@@ -10,11 +9,10 @@ from . import openai_bp
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ... import opts
-from ...cluster.backend import get_a_cluster_backend
 from ...database.database import log_prompt
 from ...llm.generator import generator
-from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
-from ...llm.vllm import tokenize
+from ...llm.openai.oai_to_vllm import oai_to_vllm
+from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 
 
 # TODO: add rate-limit headers?
@@ -25,32 +23,46 @@ def openai_chat_completions():
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
-        handler = OpenAIRequestHandler(request, request_json_body)
-        if request_json_body.get('stream'):
+        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
+
+        if handler.cluster_backend_info['mode'] != 'vllm':
+            # TODO: implement other backends
+            raise NotImplementedError
+
+        if not request_json_body.get('stream'):
+            try:
+                return handler.handle_request()
+            except Exception:
+                traceback.print_exc()
+                return 'Internal server error', 500
+        else:
             if not opts.enable_streaming:
                 # TODO: return a proper OAI error message
                 return 'disabled', 401
 
-            if opts.mode != 'vllm':
-                # TODO: implement other backends
-                raise NotImplementedError
+            if opts.openai_silent_trim:
+                handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
 
             response_status_code = 0
             start_time = time.time()
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
-                # TODO: simulate OAI here
-                raise Exception('TODO: simulate OAI here')
+                return invalid_response
             else:
-                handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
+                if opts.openai_silent_trim:
+                    oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+                else:
+                    oai_messages = handler.request.json['messages']
+
+                handler.prompt = transform_messages_to_prompt(oai_messages)
+                handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
                 try:
-                    cluster_backend = get_a_cluster_backend()
-                    response = generator(msg_to_backend, cluster_backend)
+                    response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
                     r_url = request.url
                     model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@@ -94,22 +106,20 @@ def openai_chat_completions():
                         end_time = time.time()
                         elapsed_time = end_time - start_time
-
-                        def background_task():
-                            generated_tokens = tokenize(generated_text)
-                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
-
-                        # TODO: use async/await instead of threads
-                        thread = threading.Thread(target=background_task)
-                        thread.start()
-                        thread.join()
+                        log_prompt(
+                            handler.client_ip,
+                            handler.token,
+                            handler.prompt,
+                            generated_text,
+                            elapsed_time,
+                            handler.parameters,
+                            r_headers,
+                            response_status_code,
+                            r_url,
+                            handler.backend_url,
+                        )
 
                     return Response(generate(), mimetype='text/event-stream')
                 except:
                     # TODO: simulate OAI here
                     raise Exception
-    else:
-        try:
-            return handler.handle_request()
-        except Exception:
-            traceback.print_exc()
-            return 'Internal server error', 500
@@ -1,15 +1,19 @@
 import time
 import traceback
 
-from flask import jsonify, make_response, request
+import simplejson as json
+from flask import Response, jsonify, request
 
-from . import openai_bp
 from llm_server.custom_redis import redis
+from . import openai_bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
+from ...database.database import log_prompt
 from ...llm import get_token_count
-from ...llm.openai.transform import generate_oai_string
+from ...llm.generator import generator
+from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
+from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
 
 
 # TODO: add rate-limit headers?
@@ -21,40 +25,137 @@ def openai_completions():
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         try:
-            response, status_code = OobaRequestHandler(request).handle_request()
-            if status_code != 200:
-                return status_code
-            output = response.json['results'][0]['text']
-
-            # TODO: async/await
-            prompt_tokens = get_token_count(request_json_body['prompt'])
-            response_tokens = get_token_count(output)
-            running_model = redis.get('running_model', 'ERROR', dtype=str)
-
-            response = make_response(jsonify({
-                "id": f"cmpl-{generate_oai_string(30)}",
-                "object": "text_completion",
-                "created": int(time.time()),
-                "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
-                "choices": [
-                    {
-                        "text": output,
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": None
-                    }
-                ],
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": response_tokens,
-                    "total_tokens": prompt_tokens + response_tokens
-                }
-            }), 200)
-
-            stats = redis.get('proxy_stats', dtype=dict)
-            if stats:
-                response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
-            return response
+            handler = OobaRequestHandler(incoming_request=request)
+
+            if handler.cluster_backend_info['mode'] != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
+            invalid_oai_err_msg = validate_oai(handler.request_json_body)
+            if invalid_oai_err_msg:
+                return invalid_oai_err_msg
+            handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+
+            # Convert parameters to the selected backend type
+            if opts.openai_silent_trim:
+                handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+            else:
+                # The handle_request() call below will load the prompt so we don't have
+                # to do anything else here.
+                pass
+
+            if not request_json_body.get('stream'):
+                response, status_code = handler.handle_request()
+                if status_code != 200:
+                    return status_code
+                output = response.json['results'][0]['text']
+
+                # TODO: async/await
+                prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
+                response_tokens = get_token_count(output, handler.backend_url)
+                running_model = redis.get('running_model', 'ERROR', dtype=str)
+
+                response = jsonify({
+                    "id": f"cmpl-{generate_oai_string(30)}",
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
+                    "choices": [
+                        {
+                            "text": output,
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": "stop"
+                        }
+                    ],
+                    "usage": {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": response_tokens,
+                        "total_tokens": prompt_tokens + response_tokens
+                    }
+                })
+
+                stats = redis.get('proxy_stats', dtype=dict)
+                if stats:
+                    response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+                return response, 200
+            else:
+                if not opts.enable_streaming:
+                    # TODO: return a proper OAI error message
+                    return 'disabled', 401
+
+                response_status_code = 0
+                start_time = time.time()
+
+                request_valid, invalid_response = handler.validate_request()
+                if not request_valid:
+                    # TODO: simulate OAI here
+                    raise Exception('TODO: simulate OAI here')
+                else:
+                    handler.prompt = handler.request_json_body['prompt']
+                    msg_to_backend = {
+                        **handler.parameters,
+                        'prompt': handler.prompt,
+                        'stream': True,
+                    }
+                    response = generator(msg_to_backend, handler.backend_url)
+                    r_headers = dict(request.headers)
+                    r_url = request.url
+                    model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
+                    oai_string = generate_oai_string(30)
+
+                    def generate():
+                        generated_text = ''
+                        partial_response = b''
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                            generated_text = generated_text + new
+                                        except IndexError:
+                                            # ????
+                                            continue
+
+                                        data = {
+                                            "id": f"chatcmpl-{oai_string}",
+                                            "object": "text_completion",
+                                            "created": int(time.time()),
+                                            "model": model,
+                                            "choices": [
+                                                {
+                                                    "index": 0,
+                                                    "delta": {
+                                                        "content": new
+                                                    },
+                                                    "finish_reason": None
+                                                }
+                                            ]
+                                        }
+                                        yield f'data: {json.dumps(data)}\n\n'
+
+                        yield 'data: [DONE]\n\n'
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+
+                        log_prompt(
+                            handler.client_ip,
+                            handler.token,
+                            handler.prompt,
+                            generated_text,
+                            elapsed_time,
+                            handler.parameters,
+                            r_headers,
+                            response_status_code,
+                            r_url,
+                            handler.backend_url,
+                        )
+
+                    return Response(generate(), mimetype='text/event-stream')
         except Exception:
             traceback.print_exc()
             return 'Internal Server Error', 500
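For reference, a hedged sketch of how a client might consume the `text/event-stream` response that `generate()` above produces. The host, port, and route are assumptions; the `data: {...}` framing and the `data: [DONE]` sentinel come from the code in this commit:

import json
import requests

resp = requests.post(
    'http://localhost:5000/api/openai/v1/completions',  # assumed base URL and path
    json={'model': 'some-model', 'prompt': 'Hello', 'temperature': 0.7, 'stream': True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '):
        continue
    payload = line[len('data: '):]
    if payload == '[DONE]':
        break
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)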
@@ -3,24 +3,24 @@ import traceback
 import requests
 from flask import jsonify
 
-from . import openai_bp
 from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
+from . import openai_bp
 from ..stats import server_start_time
 from ... import opts
 from ...cluster.backend import get_a_cluster_backend
+from ...cluster.cluster_config import cluster_config
 from ...helpers import jsonify_pretty
-from ...llm.info import get_running_model
+from ...llm.openai.transform import generate_oai_string
 
 
 @openai_bp.route('/models', methods=['GET'])
 @flask_cache.cached(timeout=60, query_string=True)
 def openai_list_models():
-    model, error = get_running_model()
-    if not model:
+    model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',
-            'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         running_model = redis.get('running_model', 'ERROR', dtype=str)
@@ -65,7 +65,14 @@ def fetch_openai_models():
     if opts.openai_api_key:
         try:
             response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
-            return response.json()['data']
+            j = response.json()['data']
+
+            # The "modelperm" string appears to be user-specific, so we'll
+            # randomize it just to be safe.
+            for model in range(len(j)):
+                for p in range(len(j[model]['permission'])):
+                    j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
+            return j
         except:
             traceback.print_exc()
             return []
@@ -17,7 +17,7 @@ def openai_organizations():
             "id": f"org-{generate_oai_string(24)}",
             "created": int(server_start_time.timestamp()),
             "title": "Personal",
-            "name": "user-abcdefghijklmnopqrstuvwx",
+            "name": f"user-{generate_oai_string(24)}",
             "description": "Personal org for bobjoe@0.0.0.0",
             "personal": True,
             "is_default": True,
@@ -1,14 +1,19 @@
 import json
+import re
+import time
 import traceback
 from typing import Tuple
 from uuid import uuid4
 
 import flask
-from flask import jsonify
+from flask import Response, jsonify, make_response
 
+import llm_server
 from llm_server import opts
+from llm_server.custom_redis import redis
 from llm_server.database.database import is_api_key_moderated
-from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
+from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
+from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.workers.moderator import add_moderation_task, get_results
@@ -22,7 +27,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
 
         if opts.openai_silent_trim:
-            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
+            oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
         else:
             oai_messages = self.request.json['messages']
 
@@ -51,13 +56,8 @@ class OpenAIRequestHandler(RequestHandler):
             print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
             print(traceback.format_exc())
 
-        # Reconstruct the request JSON with the validated parameters and prompt.
-        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
-        if opts.openai_force_no_hashes:
-            self.parameters['stop'].append('### ')
-
-        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
-            self.request_json_body['top_p'] = 0.01
+        # TODO: support Ooba
+        self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
 
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@@ -65,7 +65,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')
 
         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
+            return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
 
@@ -75,7 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
         return 'Ratelimited', 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        # TODO: return a simulated OpenAI error message
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",
@@ -84,3 +83,52 @@ class OpenAIRequestHandler(RequestHandler):
                 "code": None
             }
         }), 400
+
+    def build_openai_response(self, prompt, response, model=None):
+        # Seperate the user's prompt from the context
+        x = prompt.split('### USER:')
+        if len(x) > 1:
+            prompt = re.sub(r'\n$', '', x[-1].strip(' '))
+
+        # Make sure the bot doesn't put any other instructions in its response
+        response = re.sub(ANTI_RESPONSE_RE, '', response)
+        response = re.sub(ANTI_CONTINUATION_RE, '', response)
+
+        # TODO: async/await
+        prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
+        response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
+        running_model = redis.get('running_model', 'ERROR', dtype=str)
+
+        response = make_response(jsonify({
+            "id": f"chatcmpl-{generate_oai_string(30)}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": running_model if opts.openai_expose_our_model else model,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": response,
+                },
+                "logprobs": None,
+                "finish_reason": "stop"
+            }],
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": response_tokens,
+                "total_tokens": prompt_tokens + response_tokens
+            }
+        }), 200)
+
+        stats = redis.get('proxy_stats', dtype=dict)
+        if stats:
+            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+        return response
+
+    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
+        invalid_oai_err_msg = validate_oai(self.request_json_body)
+        if invalid_oai_err_msg:
+            return False, invalid_oai_err_msg
+        self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        # If the parameters were invalid, let the superclass deal with it.
+        return super().validate_request(prompt, do_log)
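Both `handle_error()` and the new `validate_request()` override (via `validate_oai`) now answer bad requests with an OpenAI-shaped error body and HTTP 400. A small reference sketch of that payload shape, using the message text from `handle_error()` as the example:

error_body = {
    "error": {
        "message": "Invalid request, check your parameters and try again.",
        "type": "invalid_request_error",
        "param": None,
        "code": None,
    }
}

# Clients can tell these simulated OpenAI errors apart from backend failures by
# the presence of the "error" key and its "type" field.
assert error_body["error"]["type"] == "invalid_request_error"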
@@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
-from llm_server.routes.queue import RedisPriorityQueue, priority_queue
+from llm_server.routes.queue import priority_queue
 
 DEFAULT_PRIORITY = 9999
 
 
 class RequestHandler:
-    def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
+    def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
         self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
 
@@ -41,7 +41,7 @@ class RequestHandler:
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
 
         if not self.cluster_backend_info.get('mode'):
-            print(self.backend_url, self.cluster_backend_info)
+            print(selected_model, self.backend_url, self.cluster_backend_info)
 
         self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
         self.parameters = None
@@ -5,8 +5,6 @@ from flask import jsonify, request
 from . import bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
-from ...cluster.backend import get_a_cluster_backend
-from ...cluster.cluster_config import cluster_config
 
 
 @bp.route('/v1/generate', methods=['POST'])
@@ -71,6 +71,7 @@ def generate_stats(regen: bool = False):
                 'model': backend_info['model'],
                 'mode': backend_info['mode'],
                 'nvidia': backend_info['nvidia'],
+                'priority': backend_info['priority'],
             }
         else:
             output['backend_info'] = {}
@@ -84,7 +84,7 @@ def do_stream(ws, model_name):
             ws.close()
             return auth_failure
 
-    handler = OobaRequestHandler(request, model_name, request_json_body)
+    handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
     generated_text = ''
     input_prompt = request_json_body['prompt']
     response_status_code = 0
@@ -4,10 +4,8 @@ from flask import jsonify, request
 
 from llm_server.custom_redis import flask_cache
 from . import bp
-from ..auth import requires_auth
 from ... import opts
-from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
-from ...cluster.cluster_config import cluster_config
+from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
 
 
 @bp.route('/v1/model', methods=['GET'])
@@ -39,14 +37,3 @@ def get_model(model_name=None):
         flask_cache.set(cache_key, response, timeout=60)
 
     return response
-
-
-@bp.route('/backends', methods=['GET'])
-@requires_auth
-def get_backend():
-    online, offline = get_backends()
-    result = {}
-    for i in online + offline:
-        info = cluster_config.get_backend(i)
-        result[info['hash']] = info
-    return jsonify(result), 200
@@ -1,6 +1,11 @@
+from flask import jsonify
+
+from llm_server.custom_redis import flask_cache
 from . import bp
 from .generate_stats import generate_stats
-from llm_server.custom_redis import flask_cache
+from ..auth import requires_auth
+from ...cluster.backend import get_backends
+from ...cluster.cluster_config import cluster_config
 from ...helpers import jsonify_pretty
 
 
@@ -8,3 +13,14 @@ from ...helpers import jsonify_pretty
 @flask_cache.cached(timeout=5, query_string=True)
 def get_stats():
     return jsonify_pretty(generate_stats())
+
+
+@bp.route('/backends', methods=['GET'])
+@requires_auth
+def get_backend():
+    online, offline = get_backends()
+    result = {}
+    for i in online + offline:
+        info = cluster_config.get_backend(i)
+        result[info['hash']] = info
+    return jsonify(result), 200
@@ -24,11 +24,13 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.sock import init_socketio
 
-# TODO: add a way to cancel VLLM gens. Maybe use websockets?
-# TODO: need to update opts. for workers
-# TODO: add a healthcheck to VLLM
+# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
 
 # Lower priority
+# TODO: support logit_bias on OpenAI and Ooba endpoints.
+# TODO: add a way to cancel VLLM gens. Maybe use websockets?
+# TODO: validate openai_silent_trim works as expected and only when enabled
+# TODO: rewrite config storage. Store in redis so we can reload it.
 # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
 # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
 # TODO: the estiamted wait time lags behind the stats
|
||||||
|
|
Reference in New Issue