update openai endpoints

Cyberes 2023-10-01 14:15:01 -06:00
parent 93d19fb95b
commit 2a3ff7e21e
18 changed files with 384 additions and 161 deletions

View File

@ -43,6 +43,8 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`

View File

@ -60,7 +60,8 @@ def get_model_choices(regen: bool = False):
if len(context_size):
model_choices[model]['context_size'] = min(context_size)
model_choices = dict(sorted(model_choices.items())) # Python wants to sort lowercase vs. uppercase letters differently.
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
default_backend = get_a_cluster_backend()
default_backend_dict = {}

View File

@ -2,7 +2,6 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.llm import get_token_count
@ -36,6 +35,8 @@ class LLMBackend:
""" """
If a backend needs to do other checks not related to the prompt or parameters. If a backend needs to do other checks not related to the prompt or parameters.
Default is no extra checks preformed. Default is no extra checks preformed.
:param request:
:param prompt:
:param parameters:
:return:
"""

View File

@ -0,0 +1,63 @@
from flask import jsonify
from llm_server import opts
def oai_to_vllm(request_json_body, hashes: bool, mode):
if not request_json_body.get('stop'):
request_json_body['stop'] = []
if hashes:
request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
request_json_body['stop'].append('### ')
else:
request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
if request_json_body.get('frequency_penalty', 0) < -2:
request_json_body['frequency_penalty'] = -2
elif request_json_body.get('frequency_penalty', 0) > 2:
request_json_body['frequency_penalty'] = 2
if mode == 'vllm' and request_json_body.get('top_p') == 0:
request_json_body['top_p'] = 0.01
return request_json_body
def format_oai_err(err_msg):
return jsonify({
"error": {
"message": err_msg,
"type": "invalid_request_error",
"param": None,
"code": None
}
}), 400
def validate_oai(parameters):
if parameters['temperature'] > 2:
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
if parameters['temperature'] < 0:
return format_oai_err(f"{parameters['temperature']} is less than the minimum of 0 - 'temperature'")
if parameters.get('top_p', 1) > 1:
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
if parameters.get('top_p', 1) < 0:
return format_oai_err(f"{parameters['top_p']} is less than the minimum of 0 - 'top_p'")
if parameters.get('presence_penalty', 1) > 2:
return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
if parameters.get('presence_penalty', 1) < -2:
return format_oai_err(f"{parameters['presence_penalty']} is less than the minimum of -2 - 'presence_penalty'")

View File

@ -2,73 +2,24 @@ import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.custom_redis import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def build_openai_response(prompt, response, model=None):
# Seperate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
@ -95,13 +46,13 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
break
def get_token_count_thread(msg):
return get_token_count(msg["content"])
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
@ -109,6 +60,40 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
return prompt
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg))
token_count = get_token_count_tiktoken_thread(prompt)
# If total tokens exceed the limit, start trimming
if token_count > context_token_limit:
while True:
while token_count > context_token_limit:
# Calculate the index to start removing characters from
remove_index = len(prompt) // 3
while remove_index < len(prompt):
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
token_count = get_token_count_tiktoken_thread(prompt)
if token_count <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg, backend_url)
token_count = get_token_count_thread(prompt)
if token_count > context_token_limit:
# Start over, but this time calculate the token count using the backend
token_count = get_token_count_thread(prompt)
else:
break
return prompt
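
For reference, the two trimming helpers added above differ in input shape: `trim_messages_to_fit()` works on an OpenAI-style message list, while `trim_string_to_fit()` cuts chunks of characters out of a plain prompt string. A minimal sketch of calling them follows; the backend URL and context limit are placeholders rather than values from the commit, and both calls need a reachable backend for the final token count.

```python
# Hypothetical usage; llm_server must be importable and the backend reachable.
from llm_server.llm.openai.transform import trim_messages_to_fit, trim_string_to_fit

backend_url = 'http://127.0.0.1:7000'  # placeholder backend address
context_limit = 4096                   # e.g. model_config['max_position_embeddings']

# Chat endpoint: trim the message list until prompt plus formatting fits the context.
messages = [{'role': 'user', 'content': 'word ' * 10000}]
messages = trim_messages_to_fit(messages, context_limit, backend_url)

# Completion endpoint: remove character chunks from the prompt string instead.
prompt = 'word ' * 10000
prompt = trim_string_to_fit(prompt, context_limit, backend_url)
```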

View File

@ -1,4 +1,3 @@
import threading
from typing import Tuple
from flask import jsonify
@ -35,9 +34,11 @@ class VLLMBackend(LLMBackend):
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self._default_params['stop']),
stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
)
except ValueError as e:
return None, str(e).strip('.')
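
The new `stop` and `max_tokens` lines above merge Ooba-style and OpenAI-style parameter names with an `or` fallback, which only kicks in when the first lookup comes back falsy. The snippet below replays that logic in isolation; the default values are hypothetical, not the backend's real `_default_params`.

```python
# Standalone illustration of the fallback above, not the backend class itself.
default_params = {'stop': [], 'max_tokens': None}      # assumed defaults
parameters = {'stop': ['\nuser:'], 'max_tokens': 100}  # a made-up OpenAI-style request

# 'stopping_strings' (Ooba) is checked first; its default ([]) is falsy here,
# so the OpenAI 'stop' list is used. set() removes duplicate stop strings.
stop = list(set(parameters.get('stopping_strings', default_params['stop'])
                or parameters.get('stop', default_params['stop'])))

# Same pattern for the token limit: 'max_new_tokens' first, then 'max_tokens'.
max_tokens = (parameters.get('max_new_tokens', default_params['max_tokens'])
              or parameters.get('max_tokens', default_params['max_tokens']))

print(stop)        # ['\nuser:']
print(max_tokens)  # 100
```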

View File

@ -1,5 +1,4 @@
import json
import threading
import time
import traceback
@ -10,11 +9,10 @@ from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize
from ...llm.openai.oai_to_vllm import oai_to_vllm
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
# TODO: add rate-limit headers?
@ -25,32 +23,46 @@ def openai_chat_completions():
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
handler = OpenAIRequestHandler(request, request_json_body)
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
if request_json_body.get('stream'):
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if not request_json_body.get('stream'):
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if opts.openai_silent_trim:
handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
return invalid_response
else:
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
if opts.openai_silent_trim:
oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
oai_messages = handler.request.json['messages']
handler.prompt = transform_messages_to_prompt(oai_messages)
handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
cluster_backend = get_a_cluster_backend()
response = generator(msg_to_backend, cluster_backend)
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@ -94,22 +106,20 @@ def openai_chat_completions():
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500

View File

@ -1,15 +1,19 @@
import time
import traceback
from flask import jsonify, make_response, request
import simplejson as json
from flask import Response, jsonify, request
from . import openai_bp
from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm import get_token_count
from ...llm.openai.transform import generate_oai_string
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers?
@ -21,40 +25,137 @@ def openai_completions():
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
response, status_code = OobaRequestHandler(request).handle_request()
if status_code != 200:
return status_code
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
handler = OobaRequestHandler(incoming_request=request)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
# Convert parameters to the selected backend type
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
if not request_json_body.get('stream'):
response, status_code = handler.handle_request()
if status_code != 200:
return status_code
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
})
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
return response, 200
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = handler.request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'Internal Server Error', 500

View File

@ -3,24 +3,24 @@ import traceback
import requests
from flask import jsonify
from . import openai_bp
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
from ..stats import server_start_time
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...cluster.cluster_config import cluster_config
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
model, error = get_running_model()
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not model:
if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', 'ERROR', dtype=str)
@ -65,7 +65,14 @@ def fetch_openai_models():
if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data']
j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll
# randomize it just to be safe.
for model in range(len(j)):
for p in range(len(j[model]['permission'])):
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
return j
except:
traceback.print_exc()
return []

View File

@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}", "id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"title": "Personal", "title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx", "name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0", "description": "Personal org for bobjoe@0.0.0.0",
"personal": True, "personal": True,
"is_default": True, "is_default": True,

View File

@ -1,14 +1,19 @@
import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
from flask import jsonify
from flask import Response, jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@ -22,7 +27,7 @@ class OpenAIRequestHandler(RequestHandler):
assert not self.used
if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
@ -51,13 +56,8 @@ class OpenAIRequestHandler(RequestHandler):
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
# TODO: support Ooba
self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@ -65,7 +65,7 @@ class OpenAIRequestHandler(RequestHandler):
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
@ -75,7 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
return 'Ratelimited', 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
# TODO: return a simulated OpenAI error message
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@ -84,3 +83,52 @@ class OpenAIRequestHandler(RequestHandler):
"code": None "code": None
} }
}), 400 }), 400
def build_openai_response(self, prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)

View File

@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
@ -41,7 +41,7 @@ class RequestHandler:
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
if not self.cluster_backend_info.get('mode'):
print(self.backend_url, self.cluster_backend_info)
print(selected_model, self.backend_url, self.cluster_backend_info)
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
self.parameters = None

View File

@ -5,8 +5,6 @@ from flask import jsonify, request
from . import bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ...cluster.backend import get_a_cluster_backend
from ...cluster.cluster_config import cluster_config
@bp.route('/v1/generate', methods=['POST'])

View File

@ -71,6 +71,7 @@ def generate_stats(regen: bool = False):
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': backend_info['nvidia'],
'priority': backend_info['priority'],
}
else:
output['backend_info'] = {}

View File

@ -84,7 +84,7 @@ def do_stream(ws, model_name):
ws.close()
return auth_failure
handler = OobaRequestHandler(request, model_name, request_json_body)
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0

View File

@ -4,10 +4,8 @@ from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from ... import opts
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config
@bp.route('/v1/model', methods=['GET'])
@ -39,14 +37,3 @@ def get_model(model_name=None):
flask_cache.set(cache_key, response, timeout=60)
return response
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
online, offline = get_backends()
result = {}
for i in online + offline:
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -1,6 +1,11 @@
from flask import jsonify
from llm_server.custom_redis import flask_cache
from . import bp
from .generate_stats import generate_stats
from llm_server.custom_redis import flask_cache
from ..auth import requires_auth
from ...cluster.backend import get_backends
from ...cluster.cluster_config import cluster_config
from ...helpers import jsonify_pretty
@ -8,3 +13,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
online, offline = get_backends()
result = {}
for i in online + offline:
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -24,11 +24,13 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.sock import init_socketio
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# TODO: need to update opts. for workers
# TODO: add a healthcheck to VLLM
# Lower priority
# TODO: support logit_bias on OpenAI and Ooba endpoints.
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: validate openai_silent_trim works as expected and only when enabled
# TODO: rewrite config storage. Store in redis so we can reload it.
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats