update openai endpoints
This commit is contained in:
parent
93d19fb95b
commit
2a3ff7e21e
|
@ -43,6 +43,8 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
|
|||
|
||||
### Use
|
||||
|
||||
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
|
||||
|
||||
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
|
||||
|
||||
|
||||
|
|
|
@ -60,7 +60,8 @@ def get_model_choices(regen: bool = False):
|
|||
if len(context_size):
|
||||
model_choices[model]['context_size'] = min(context_size)
|
||||
|
||||
model_choices = dict(sorted(model_choices.items()))
|
||||
# Python wants to sort lowercase vs. uppercase letters differently.
|
||||
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
|
||||
|
||||
default_backend = get_a_cluster_backend()
|
||||
default_backend_dict = {}
|
||||
|
|
|
@ -2,7 +2,6 @@ from typing import Tuple, Union
|
|||
|
||||
import flask
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.llm import get_token_count
|
||||
|
@ -36,6 +35,8 @@ class LLMBackend:
|
|||
"""
|
||||
If a backend needs to do other checks not related to the prompt or parameters.
|
||||
Default is no extra checks preformed.
|
||||
:param request:
|
||||
:param prompt:
|
||||
:param parameters:
|
||||
:return:
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def oai_to_vllm(request_json_body, hashes: bool, mode):
|
||||
if not request_json_body.get('stop'):
|
||||
request_json_body['stop'] = []
|
||||
|
||||
if hashes:
|
||||
request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
if opts.openai_force_no_hashes:
|
||||
request_json_body['stop'].append('### ')
|
||||
else:
|
||||
request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
|
||||
|
||||
if request_json_body.get('frequency_penalty', 0) < -2:
|
||||
request_json_body['frequency_penalty'] = -2
|
||||
elif request_json_body.get('frequency_penalty', 0) > 2:
|
||||
request_json_body['frequency_penalty'] = 2
|
||||
|
||||
if mode == 'vllm' and request_json_body.get('top_p') == 0:
|
||||
request_json_body['top_p'] = 0.01
|
||||
|
||||
return request_json_body
|
||||
|
||||
|
||||
def format_oai_err(err_msg):
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": err_msg,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}), 400
|
||||
|
||||
|
||||
def validate_oai(parameters):
|
||||
if parameters['temperature'] > 2:
|
||||
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
|
||||
if parameters['temperature'] < 0:
|
||||
return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
||||
|
||||
if parameters.get('presence_penalty', 1) > 2:
|
||||
return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
|
||||
if parameters.get('presence_penalty', 1) < -2:
|
||||
return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
|
@ -2,73 +2,24 @@ import concurrent.futures
|
|||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List
|
||||
|
||||
import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
|
||||
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
|
||||
|
||||
|
||||
def build_openai_response(prompt, response, model=None):
|
||||
# Seperate the user's prompt from the context
|
||||
x = prompt.split('### USER:')
|
||||
if len(x) > 1:
|
||||
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
|
||||
|
||||
# Make sure the bot doesn't put any other instructions in its response
|
||||
# y = response.split('\n### ')
|
||||
# if len(y) > 1:
|
||||
# response = re.sub(r'\n$', '', y[0].strip(' '))
|
||||
response = re.sub(ANTI_RESPONSE_RE, '', response)
|
||||
response = re.sub(ANTI_CONTINUATION_RE, '', response)
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
||||
|
||||
def generate_oai_string(length=24):
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for i in range(length))
|
||||
|
||||
|
||||
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
|
||||
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def get_token_count_tiktoken_thread(msg):
|
||||
|
@ -95,13 +46,13 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
|
|||
break
|
||||
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg["content"])
|
||||
return get_token_count(msg["content"], backend_url)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
|
||||
total_tokens = sum(token_counts)
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
|
||||
|
||||
if total_tokens + formatting_tokens > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
|
@ -109,6 +60,40 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
|
|||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
else:
|
||||
break
|
||||
return prompt
|
||||
|
||||
|
||||
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def get_token_count_tiktoken_thread(msg):
|
||||
return len(tokenizer.encode(msg))
|
||||
|
||||
token_count = get_token_count_tiktoken_thread(prompt)
|
||||
|
||||
# If total tokens exceed the limit, start trimming
|
||||
if token_count > context_token_limit:
|
||||
while True:
|
||||
while token_count > context_token_limit:
|
||||
# Calculate the index to start removing characters from
|
||||
remove_index = len(prompt) // 3
|
||||
|
||||
while remove_index < len(prompt):
|
||||
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
|
||||
token_count = get_token_count_tiktoken_thread(prompt)
|
||||
if token_count <= context_token_limit or remove_index == len(prompt):
|
||||
break
|
||||
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg, backend_url)
|
||||
|
||||
token_count = get_token_count_thread(prompt)
|
||||
|
||||
if token_count > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
token_count = get_token_count_thread(prompt)
|
||||
else:
|
||||
break
|
||||
|
||||
return prompt
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
from flask import jsonify
|
||||
|
@ -35,9 +34,11 @@ class VLLMBackend(LLMBackend):
|
|||
top_p=parameters.get('top_p', self._default_params['top_p']),
|
||||
top_k=top_k,
|
||||
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
|
||||
stop=parameters.get('stopping_strings', self._default_params['stop']),
|
||||
stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
|
||||
ignore_eos=parameters.get('ban_eos_token', False),
|
||||
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
|
||||
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
|
||||
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
|
||||
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
|
||||
)
|
||||
except ValueError as e:
|
||||
return None, str(e).strip('.')
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
@ -10,11 +9,10 @@ from . import openai_bp
|
|||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
|
||||
from ...llm.vllm import tokenize
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
@ -25,32 +23,46 @@ def openai_chat_completions():
|
|||
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
|
||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
else:
|
||||
handler = OpenAIRequestHandler(request, request_json_body)
|
||||
if request_json_body.get('stream'):
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal server error', 500
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a proper OAI error message
|
||||
return 'disabled', 401
|
||||
|
||||
if opts.mode != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception('TODO: simulate OAI here')
|
||||
return invalid_response
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
|
||||
if opts.openai_silent_trim:
|
||||
oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
oai_messages = handler.request.json['messages']
|
||||
|
||||
handler.prompt = transform_messages_to_prompt(oai_messages)
|
||||
handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
try:
|
||||
cluster_backend = get_a_cluster_backend()
|
||||
response = generator(msg_to_backend, cluster_backend)
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
|
@ -94,22 +106,20 @@ def openai_chat_completions():
|
|||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
def background_task():
|
||||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task)
|
||||
thread.start()
|
||||
thread.join()
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception
|
||||
else:
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal server error', 500
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import time
|
||||
import traceback
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
import simplejson as json
|
||||
from flask import Response, jsonify, request
|
||||
|
||||
from . import openai_bp
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm import get_token_count
|
||||
from ...llm.openai.transform import generate_oai_string
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
@ -21,40 +25,137 @@ def openai_completions():
|
|||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
try:
|
||||
response, status_code = OobaRequestHandler(request).handle_request()
|
||||
if status_code != 200:
|
||||
return status_code
|
||||
output = response.json['results'][0]['text']
|
||||
handler = OobaRequestHandler(incoming_request=request)
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'])
|
||||
response_tokens = get_token_count(output)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
|
||||
"choices": [
|
||||
{
|
||||
"text": output,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": None
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
# Convert parameters to the selected backend type
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
# to do anything else here.
|
||||
pass
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
response, status_code = handler.handle_request()
|
||||
if status_code != 200:
|
||||
return status_code
|
||||
output = response.json['results'][0]['text']
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
|
||||
response_tokens = get_token_count(output, handler.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
|
||||
"choices": [
|
||||
{
|
||||
"text": output,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
})
|
||||
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response, 200
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a proper OAI error message
|
||||
return 'disabled', 401
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception('TODO: simulate OAI here')
|
||||
else:
|
||||
handler.prompt = handler.request_json_body['prompt']
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
data = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal Server Error', 500
|
||||
|
|
|
@ -3,24 +3,24 @@ import traceback
|
|||
import requests
|
||||
from flask import jsonify
|
||||
|
||||
from . import openai_bp
|
||||
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
|
||||
from . import openai_bp
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
from ...helpers import jsonify_pretty
|
||||
from ...llm.info import get_running_model
|
||||
from ...llm.openai.transform import generate_oai_string
|
||||
|
||||
|
||||
@openai_bp.route('/models', methods=['GET'])
|
||||
@flask_cache.cached(timeout=60, query_string=True)
|
||||
def openai_list_models():
|
||||
model, error = get_running_model()
|
||||
if not model:
|
||||
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
|
||||
if not model_name:
|
||||
response = jsonify({
|
||||
'code': 502,
|
||||
'msg': 'failed to reach backend',
|
||||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
else:
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
@ -65,7 +65,14 @@ def fetch_openai_models():
|
|||
if opts.openai_api_key:
|
||||
try:
|
||||
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
|
||||
return response.json()['data']
|
||||
j = response.json()['data']
|
||||
|
||||
# The "modelperm" string appears to be user-specific, so we'll
|
||||
# randomize it just to be safe.
|
||||
for model in range(len(j)):
|
||||
for p in range(len(j[model]['permission'])):
|
||||
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
|
||||
return j
|
||||
except:
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
|
|
@ -17,7 +17,7 @@ def openai_organizations():
|
|||
"id": f"org-{generate_oai_string(24)}",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"title": "Personal",
|
||||
"name": "user-abcdefghijklmnopqrstuvwx",
|
||||
"name": f"user-{generate_oai_string(24)}",
|
||||
"description": "Personal org for bobjoe@0.0.0.0",
|
||||
"personal": True,
|
||||
"is_default": True,
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import flask
|
||||
from flask import jsonify
|
||||
from flask import Response, jsonify, make_response
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated
|
||||
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
from llm_server.workers.moderator import add_moderation_task, get_results
|
||||
|
||||
|
@ -22,7 +27,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
assert not self.used
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
|
||||
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
||||
else:
|
||||
oai_messages = self.request.json['messages']
|
||||
|
||||
|
@ -51,13 +56,8 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
if opts.openai_force_no_hashes:
|
||||
self.parameters['stop'].append('### ')
|
||||
|
||||
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
|
||||
self.request_json_body['top_p'] = 0.01
|
||||
# TODO: support Ooba
|
||||
self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
|
||||
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
|
@ -65,7 +65,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
model = self.request_json_body.get('model')
|
||||
|
||||
if success:
|
||||
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
|
||||
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
|
||||
else:
|
||||
return backend_response, backend_response_status_code
|
||||
|
||||
|
@ -75,7 +75,6 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
return 'Ratelimited', 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
# TODO: return a simulated OpenAI error message
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": "Invalid request, check your parameters and try again.",
|
||||
|
@ -84,3 +83,52 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
"code": None
|
||||
}
|
||||
}), 400
|
||||
|
||||
def build_openai_response(self, prompt, response, model=None):
|
||||
# Seperate the user's prompt from the context
|
||||
x = prompt.split('### USER:')
|
||||
if len(x) > 1:
|
||||
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
|
||||
|
||||
# Make sure the bot doesn't put any other instructions in its response
|
||||
response = re.sub(ANTI_RESPONSE_RE, '', response)
|
||||
response = re.sub(ANTI_CONTINUATION_RE, '', response)
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
|
||||
response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
||||
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
invalid_oai_err_msg = validate_oai(self.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return False, invalid_oai_err_msg
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
# If the parameters were invalid, let the superclass deal with it.
|
||||
return super().validate_request(prompt, do_log)
|
||||
|
|
|
@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
|||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
|
||||
self.request = incoming_request
|
||||
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
|
||||
|
@ -41,7 +41,7 @@ class RequestHandler:
|
|||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
if not self.cluster_backend_info.get('mode'):
|
||||
print(self.backend_url, self.cluster_backend_info)
|
||||
print(selected_model, self.backend_url, self.cluster_backend_info)
|
||||
|
||||
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
|
||||
self.parameters = None
|
||||
|
|
|
@ -5,8 +5,6 @@ from flask import jsonify, request
|
|||
from . import bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
@bp.route('/v1/generate', methods=['POST'])
|
||||
|
|
|
@ -71,6 +71,7 @@ def generate_stats(regen: bool = False):
|
|||
'model': backend_info['model'],
|
||||
'mode': backend_info['mode'],
|
||||
'nvidia': backend_info['nvidia'],
|
||||
'priority': backend_info['priority'],
|
||||
}
|
||||
else:
|
||||
output['backend_info'] = {}
|
||||
|
|
|
@ -84,7 +84,7 @@ def do_stream(ws, model_name):
|
|||
ws.close()
|
||||
return auth_failure
|
||||
|
||||
handler = OobaRequestHandler(request, model_name, request_json_body)
|
||||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
|
|
|
@ -4,10 +4,8 @@ from flask import jsonify, request
|
|||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from ..auth import requires_auth
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
|
||||
|
||||
|
||||
@bp.route('/v1/model', methods=['GET'])
|
||||
|
@ -39,14 +37,3 @@ def get_model(model_name=None):
|
|||
flask_cache.set(cache_key, response, timeout=60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@bp.route('/backends', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
online, offline = get_backends()
|
||||
result = {}
|
||||
for i in online + offline:
|
||||
info = cluster_config.get_backend(i)
|
||||
result[info['hash']] = info
|
||||
return jsonify(result), 200
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from flask import jsonify
|
||||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from .generate_stats import generate_stats
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from ..auth import requires_auth
|
||||
from ...cluster.backend import get_backends
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
from ...helpers import jsonify_pretty
|
||||
|
||||
|
||||
|
@ -8,3 +13,14 @@ from ...helpers import jsonify_pretty
|
|||
@flask_cache.cached(timeout=5, query_string=True)
|
||||
def get_stats():
|
||||
return jsonify_pretty(generate_stats())
|
||||
|
||||
|
||||
@bp.route('/backends', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
online, offline = get_backends()
|
||||
result = {}
|
||||
for i in online + offline:
|
||||
info = cluster_config.get_backend(i)
|
||||
result[info['hash']] = info
|
||||
return jsonify(result), 200
|
||||
|
|
|
@ -24,11 +24,13 @@ from llm_server.routes.server_error import handle_server_error
|
|||
from llm_server.routes.v1 import bp
|
||||
from llm_server.sock import init_socketio
|
||||
|
||||
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
||||
# TODO: need to update opts. for workers
|
||||
# TODO: add a healthcheck to VLLM
|
||||
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
|
||||
|
||||
# Lower priority
|
||||
# TODO: support logit_bias on OpenAI and Ooba endpoints.
|
||||
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
||||
# TODO: validate openai_silent_trim works as expected and only when enabled
|
||||
# TODO: rewrite config storage. Store in redis so we can reload it.
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
|
||||
# TODO: the estiamted wait time lags behind the stats
|
||||
|
|
Reference in New Issue