update openai endpoints

This commit is contained in:
Cyberes 2023-10-01 14:15:01 -06:00
parent 93d19fb95b
commit 2a3ff7e21e
18 changed files with 384 additions and 161 deletions

View File

@ -43,6 +43,8 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`

View File

@ -60,7 +60,8 @@ def get_model_choices(regen: bool = False):
if len(context_size):
model_choices[model]['context_size'] = min(context_size)
model_choices = dict(sorted(model_choices.items()))
# Python wants to sort lowercase vs. uppercase letters differently.
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
default_backend = get_a_cluster_backend()
default_backend_dict = {}

View File

@ -2,7 +2,6 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.llm import get_token_count
@ -36,6 +35,8 @@ class LLMBackend:
"""
If a backend needs to do other checks not related to the prompt or parameters.
Default is no extra checks performed.
:param request:
:param prompt:
:param parameters:
:return:
"""

View File

@ -0,0 +1,63 @@
from flask import jsonify
from llm_server import opts
def oai_to_vllm(request_json_body, hashes: bool, mode):
    """Adapt an OpenAI-style request body for the backend.

    The dict is modified in place and also returned.

    :param request_json_body: the request parameters. ``stop`` may be missing,
        empty, a single string or a list of strings.
    :param hashes: when true, add the ``### ...`` role markers as stop
        strings; otherwise add the plain ``user:`` / ``assistant:`` markers.
    :param mode: backend mode; ``top_p == 0`` is only rewritten for ``'vllm'``.
    :return: the same ``request_json_body`` object.
    """
    stop = request_json_body.get('stop')
    if not stop:
        stop = []
    elif isinstance(stop, str):
        # The OpenAI API accepts a bare string here; calling .extend() on it would fail.
        stop = [stop]
    elif not isinstance(stop, list):
        stop = list(stop)
    request_json_body['stop'] = stop

    if hashes:
        extra_stops = ['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']
        if opts.openai_force_no_hashes:
            extra_stops.append('### ')
    else:
        extra_stops = ['\nuser:', '\nassistant:']
    # Skip entries that are already present so that running the same body
    # through this function more than once does not keep growing the list.
    stop.extend(s for s in extra_stops if s not in stop)

    # Clamp frequency_penalty into [-2, 2]. A missing or null value is left alone.
    penalty = request_json_body.get('frequency_penalty')
    if penalty is not None:
        request_json_body['frequency_penalty'] = max(-2, min(2, penalty))

    if mode == 'vllm' and request_json_body.get('top_p') == 0:
        request_json_body['top_p'] = 0.01

    return request_json_body
def format_oai_err(err_msg):
    """Wrap ``err_msg`` in an OpenAI-shaped ``invalid_request_error`` body.

    :param err_msg: text placed in the ``message`` field.
    :return: a ``(flask response, 400)`` tuple.
    """
    error_details = {
        "message": err_msg,
        "type": "invalid_request_error",
        "param": None,
        "code": None
    }
    payload = {"error": error_details}
    return jsonify(payload), 400
def validate_oai(parameters):
    """Check the numeric sampling parameters against their allowed ranges.

    Parameters that are missing or explicitly ``None`` are skipped.

    :param parameters: the request parameters (dict-like, read with ``.get``).
    :return: ``None`` when everything is in range, otherwise the error
        produced by ``format_oai_err`` for the first offending parameter.
    """
    # (parameter name, minimum, maximum), checked in this order.
    bounds = (
        ('temperature', 0, 2),
        ('top_p', 0, 1),
        ('presence_penalty', -2, 2),
    )
    for name, minimum, maximum in bounds:
        value = parameters.get(name)
        if value is None:
            continue
        if value > maximum:
            return format_oai_err(f"{value} is greater than the maximum of {maximum} - '{name}'")
        if value < minimum:
            return format_oai_err(f"{value} less than the minimum of {minimum} - '{name}'")
    return None

View File

@ -2,73 +2,24 @@ import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.custom_redis import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def build_openai_response(prompt, response, model=None):
    """Build a ``chat.completion`` Flask response from a raw generation.

    :param prompt: the full prompt text sent to the backend.
    :param response: the raw text generated by the backend.
    :param model: model name reported when ``opts.openai_expose_our_model`` is falsy.
    :return: a Flask response with status 200.
    """
    # Separate the user's prompt from the context: keep only what follows the
    # last '### USER:' marker, if there is one.
    *earlier_turns, last_turn = prompt.split('### USER:')
    if earlier_turns:
        prompt = re.sub(r'\n$', '', last_turn.strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    response = ANTI_RESPONSE_RE.sub('', response)
    response = ANTI_CONTINUATION_RE.sub('', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', 'ERROR', dtype=str)

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": response_tokens,
        "total_tokens": prompt_tokens + response_tokens
    }
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": response,
        },
        "logprobs": None,
        "finish_reason": "stop"
    }
    body = {
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [choice],
        "usage": usage
    }
    flask_response = make_response(jsonify(body), 200)

    stats = redis.get('proxy_stats', dtype=dict)
    if stats:
        flask_response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return flask_response
def generate_oai_string(length=24):
    """Return ``length`` random ASCII letters and digits, drawn with ``secrets``."""
    charset = string.ascii_letters + string.digits
    picked = [secrets.choice(charset) for _ in range(length)]
    return ''.join(picked)
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
@ -95,13 +46,13 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
break
def get_token_count_thread(msg):
return get_token_count(msg["content"])
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
@ -109,6 +60,40 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
return prompt
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
    """Cut characters out of ``prompt`` until it fits in ``context_token_limit`` tokens.

    Trimming is driven by the local tiktoken ``cl100k_base`` count. Once that
    count fits, the result is re-checked with ``get_token_count`` against
    ``backend_url``, and trimming resumes if that count is still too high.
    A prompt whose tiktoken count already fits is returned unchanged, without
    consulting the backend.

    :param prompt: the text to trim.
    :param context_token_limit: maximum number of tokens allowed.
    :param backend_url: passed through to ``get_token_count``.
    :return: the (possibly shortened) prompt.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def estimate_tokens(text):
        return len(tokenizer.encode(text))

    token_count = estimate_tokens(prompt)
    if token_count <= context_token_limit:
        return prompt

    while True:
        while token_count > context_token_limit:
            if not prompt:
                # Nothing left to remove. Without this guard the loop would
                # spin forever when even an empty prompt is over the limit.
                return prompt
            # Remove 100-character slices starting one third of the way in,
            # so the beginning and the end of the prompt are kept.
            remove_index = len(prompt) // 3
            while remove_index < len(prompt):
                prompt = prompt[:remove_index] + prompt[remove_index + 100:]
                token_count = estimate_tokens(prompt)
                if token_count <= context_token_limit:
                    break

        # Confirm with the backend. If it still reports too many tokens, go
        # around again and trim further.
        token_count = get_token_count(prompt, backend_url)
        if token_count <= context_token_limit:
            return prompt

View File

@ -1,4 +1,3 @@
import threading
from typing import Tuple
from flask import jsonify
@ -35,9 +34,11 @@ class VLLMBackend(LLMBackend):
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self._default_params['stop']),
stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
)
except ValueError as e:
return None, str(e).strip('.')

View File

@ -1,5 +1,4 @@
import json
import threading
import time
import traceback
@ -10,11 +9,10 @@ from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize
from ...llm.openai.oai_to_vllm import oai_to_vllm
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
# TODO: add rate-limit headers?
@ -25,32 +23,46 @@ def openai_chat_completions():
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
handler = OpenAIRequestHandler(request, request_json_body)
if request_json_body.get('stream'):
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if not request_json_body.get('stream'):
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if opts.openai_silent_trim:
handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
return invalid_response
else:
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
if opts.openai_silent_trim:
oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
oai_messages = handler.request.json['messages']
handler.prompt = transform_messages_to_prompt(oai_messages)
handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
cluster_backend = get_a_cluster_backend()
response = generator(msg_to_backend, cluster_backend)
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@ -94,22 +106,20 @@ def openai_chat_completions():
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500

View File

@ -1,15 +1,19 @@
import time
import traceback
from flask import jsonify, make_response, request
import simplejson as json
from flask import Response, jsonify, request
from . import openai_bp
from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm import get_token_count
from ...llm.openai.transform import generate_oai_string
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers?
@ -21,17 +25,37 @@ def openai_completions():
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
response, status_code = OobaRequestHandler(request).handle_request()
handler = OobaRequestHandler(incoming_request=request)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
# Convert parameters to the selected backend type
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
if not request_json_body.get('stream'):
response, status_code = handler.handle_request()
if status_code != 200:
return status_code
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
@ -41,7 +65,7 @@ def openai_completions():
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
"finish_reason": "stop"
}
],
"usage": {
@ -49,12 +73,89 @@ def openai_completions():
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
})
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
return response, 200
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = handler.request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'Internal Server Error', 500

View File

@ -3,24 +3,24 @@ import traceback
import requests
from flask import jsonify
from . import openai_bp
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
from ..stats import server_start_time
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...cluster.cluster_config import cluster_config
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
model, error = get_running_model()
if not model:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', 'ERROR', dtype=str)
@ -65,7 +65,14 @@ def fetch_openai_models():
if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data']
j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll
# randomize it just to be safe.
for model in range(len(j)):
for p in range(len(j[model]['permission'])):
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
return j
except:
traceback.print_exc()
return []

View File

@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()),
"title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx",
"name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0",
"personal": True,
"is_default": True,

View File

@ -1,14 +1,19 @@
import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
from flask import jsonify
from flask import Response, jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@ -22,7 +27,7 @@ class OpenAIRequestHandler(RequestHandler):
assert not self.used
if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
@ -51,13 +56,8 @@ class OpenAIRequestHandler(RequestHandler):
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
# TODO: support Ooba
self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@ -65,7 +65,7 @@ class OpenAIRequestHandler(RequestHandler):
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
@ -75,7 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
return 'Ratelimited', 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
# TODO: return a simulated OpenAI error message
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@ -84,3 +83,52 @@ class OpenAIRequestHandler(RequestHandler):
"code": None
}
}), 400
def build_openai_response(self, prompt, response, model=None):
    """Build a ``chat.completion`` Flask response from a raw generation.

    Token counts are obtained from ``llm_server.llm.get_token_count`` using
    this handler's ``backend_url``.

    :param prompt: the full prompt text sent to the backend.
    :param response: the raw text generated by the backend.
    :param model: model name reported when ``opts.openai_expose_our_model`` is falsy.
    :return: a Flask response with status 200.
    """
    # Separate the user's prompt from the context: keep only what follows the
    # last '### USER:' marker, if there is one.
    *earlier_turns, last_turn = prompt.split('### USER:')
    if earlier_turns:
        prompt = re.sub(r'\n$', '', last_turn.strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    response = ANTI_RESPONSE_RE.sub('', response)
    response = ANTI_CONTINUATION_RE.sub('', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
    response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
    running_model = redis.get('running_model', 'ERROR', dtype=str)

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": response_tokens,
        "total_tokens": prompt_tokens + response_tokens
    }
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": response,
        },
        "logprobs": None,
        "finish_reason": "stop"
    }
    body = {
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [choice],
        "usage": usage
    }
    flask_response = make_response(jsonify(body), 200)

    stats = redis.get('proxy_stats', dtype=dict)
    if stats:
        flask_response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return flask_response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
    """Run the OpenAI range checks, convert the body, then defer to the parent.

    :param prompt: forwarded unchanged to the superclass validation.
    :param do_log: forwarded unchanged to the superclass validation.
    :return: ``(False, error)`` when ``validate_oai`` reports a problem,
        otherwise whatever the superclass ``validate_request`` returns.
    """
    oai_error = validate_oai(self.request_json_body)
    if oai_error:
        return False, oai_error

    # Hash-style stop strings are used unless the model name contains "instruct".
    use_hashes = 'instruct' not in self.request_json_body['model'].lower()
    self.request_json_body = oai_to_vllm(
        self.request_json_body,
        hashes=use_hashes,
        mode=self.cluster_backend_info['mode'],
    )

    # If the parameters were invalid, let the superclass deal with it.
    return super().validate_request(prompt, do_log)

View File

@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
@ -41,7 +41,7 @@ class RequestHandler:
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
if not self.cluster_backend_info.get('mode'):
print(self.backend_url, self.cluster_backend_info)
print(selected_model, self.backend_url, self.cluster_backend_info)
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
self.parameters = None

View File

@ -5,8 +5,6 @@ from flask import jsonify, request
from . import bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ...cluster.backend import get_a_cluster_backend
from ...cluster.cluster_config import cluster_config
@bp.route('/v1/generate', methods=['POST'])

View File

@ -71,6 +71,7 @@ def generate_stats(regen: bool = False):
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': backend_info['nvidia'],
'priority': backend_info['priority'],
}
else:
output['backend_info'] = {}

View File

@ -84,7 +84,7 @@ def do_stream(ws, model_name):
ws.close()
return auth_failure
handler = OobaRequestHandler(request, model_name, request_json_body)
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0

View File

@ -4,10 +4,8 @@ from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from ... import opts
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config
from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
@bp.route('/v1/model', methods=['GET'])
@ -39,14 +37,3 @@ def get_model(model_name=None):
flask_cache.set(cache_key, response, timeout=60)
return response
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
    """Return the cluster config entry of every online and offline backend, keyed by its ``hash``."""
    online, offline = get_backends()
    entries = (cluster_config.get_backend(backend) for backend in online + offline)
    by_hash = {entry['hash']: entry for entry in entries}
    return jsonify(by_hash), 200

View File

@ -1,6 +1,11 @@
from flask import jsonify
from llm_server.custom_redis import flask_cache
from . import bp
from .generate_stats import generate_stats
from llm_server.custom_redis import flask_cache
from ..auth import requires_auth
from ...cluster.backend import get_backends
from ...cluster.cluster_config import cluster_config
from ...helpers import jsonify_pretty
@ -8,3 +13,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
    """Return the cluster config entry of every online and offline backend, keyed by its ``hash``."""
    online, offline = get_backends()
    entries = (cluster_config.get_backend(backend) for backend in online + offline)
    by_hash = {entry['hash']: entry for entry in entries}
    return jsonify(by_hash), 200

View File

@ -24,11 +24,13 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.sock import init_socketio
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: need to update opts. for workers
# TODO: add a healthcheck to VLLM
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# Lower priority
# TODO: support logit_bias on OpenAI and Ooba endpoints.
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: validate openai_silent_trim works as expected and only when enabled
# TODO: rewrite config storage. Store in redis so we can reload it.
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats