minor changes, add admin token auth system, add route to get backend info

This commit is contained in:
Cyberes 2023-09-24 15:54:35 -06:00
parent 2678102153
commit 8d6b2ce49c
11 changed files with 73 additions and 21 deletions

View File

@@ -2,7 +2,7 @@
_An HTTP API to serve local LLM Models._
The purpose of this server is to abstract your LLM backend from your frontend API. This enables you to make changes to (or even switch) your backend without affecting your clients.
The purpose of this server is to abstract your LLM backend from your frontend API. This enables you to switch your backend while providing a stable frontend for your clients.
### Install
@@ -43,14 +43,13 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
**DO NOT** lose your database. It's used for calculating the estimated wait time based on average TPS and response tokens; if you lose those stats, your numbers will be inaccurate until the database fills back up again. If you change GPUs, you
should probably clear the `generation_time` column in the `prompts` table.
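If you do need to reset those timings, a minimal sketch along these lines should work (the database filename below is a placeholder, not something defined by this project):

```python
# Hedged sketch: reset the generation_time column after a GPU change so the
# wait-time estimate is rebuilt from fresh timings.
import sqlite3

con = sqlite3.connect('llm-server.db')  # placeholder path; use your configured database file
con.execute('UPDATE prompts SET generation_time = NULL')
con.commit()
con.close()
```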
### To Do
- Implement streaming
- Add `huggingface/text-generation-inference`
- Convince Oobabooga to implement concurrent generation
- Make sure stats work when starting from an empty database
- Make sure we're correctly canceling requests when the client cancels
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and then remove it before proxying??
- [x] Implement streaming
- [ ] Bring streaming endpoint up to the level of the blocking endpoint
- [x] Add VLLM support
- [ ] Make sure stats work when starting from an empty database
- [ ] Make sure we're correctly canceling requests when the client cancels
- [ ] Make sure the OpenAI endpoint works as expected

View File

@@ -23,6 +23,7 @@ config_default_vars = {
'expose_openai_system_prompt': True,
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
'http_host': None,
'admin_token': None,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
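For context, here is a rough sketch of how a defaults/required-vars pair like this is usually consumed; the `merge_config` helper and the import path are assumptions for illustration, not code from this commit:

```python
# Hedged sketch: fill in defaults (e.g. admin_token -> None), then fail fast
# if any required variable is missing from the user's config file.
from llm_server.config import config_default_vars, config_required_vars  # assumed module path


def merge_config(user_config: dict) -> dict:
    config = dict(config_default_vars)
    config.update(user_config)
    missing = [k for k in config_required_vars if k not in user_config]
    if missing:
        raise ValueError(f'missing required config variables: {missing}')
    return config
```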

View File

@@ -28,3 +28,4 @@ enable_streaming = True
openai_api_key = None
backend_request_timeout = 30
backend_generate_request_timeout = 95
admin_token = None

llm_server/routes/auth.py Normal file
View File

@@ -0,0 +1,39 @@
from functools import wraps

import basicauth
from flask import Response, request

from llm_server import opts


def check_auth(token):
    # The admin token may arrive either as the password of an HTTP Basic
    # header or as a plain Bearer token.
    if not opts.admin_token:
        # Deny everything if no admin token has been configured.
        return False
    password = None
    if token.startswith('Basic '):
        try:
            _, password = basicauth.decode(token)
        except Exception:
            return False
    elif token.startswith('Bearer '):
        password = token.split('Bearer ', maxsplit=1)[1]
    return password == opts.admin_token


def authenticate():
    """Sends a 401 response that enables basic auth."""
    return Response(
        'AUTHENTICATION REQUIRED', 401,
        {'WWW-Authenticate': 'Basic realm="Login Required"'})


def requires_auth(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        auth = request.headers.get('Authorization')
        if not auth or not check_auth(auth):
            return authenticate()
        return f(*args, **kwargs)
    return decorated
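As a rough illustration of the two Authorization header shapes `check_auth` accepts (this assumes the `basicauth` package's `encode()` helper and an `opts.admin_token` set by hand; it is not part of the commit):

```python
# Hedged sketch: exercising check_auth with both accepted header formats.
import basicauth

from llm_server import opts
from llm_server.routes.auth import check_auth

opts.admin_token = 'example-token'  # normally loaded from the config file

basic_header = basicauth.encode('admin', opts.admin_token)  # 'Basic YWRtaW46...'
bearer_header = f'Bearer {opts.admin_token}'

assert check_auth(basic_header)
assert check_auth(bearer_header)
assert not check_auth('Bearer wrong-token')
```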

View File

@@ -3,7 +3,7 @@ from llm_server.routes.cache import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
http_host = redis.get('http_host')
http_host = redis.get('http_host', str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
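The added `str` argument suggests the project's Redis wrapper can coerce the raw bytes redis-py returns into a requested type. A minimal sketch of such a wrapper, purely as an assumption about how `redis.get(key, dtype)` might behave:

```python
# Hedged sketch: a typed get() on top of redis-py, which returns raw bytes.
import redis


class RedisWrapper:
    def __init__(self, host: str = 'localhost', port: int = 6379):
        self._redis = redis.Redis(host=host, port=port)

    def set(self, key, value):
        self._redis.set(key, value)

    def get(self, key, dtype=None):
        value = self._redis.get(key)
        if value is None or dtype is None:
            return value
        if dtype is str:
            return value.decode('utf-8')
        return dtype(value)
```

With a wrapper like this, `redis.get('http_host', str)` hands back a decoded string for the error banner instead of bytes.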

View File

@@ -35,7 +35,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
# TODO: have routes/__init__.py point to the latest API version generate_stats()
@cache.memoize(timeout=10)
@cache.memoize(timeout=20)
def generate_stats():
model_name, error = get_running_model() # will return False when the fetch fails
if isinstance(model_name, bool):

View File

@@ -102,7 +102,7 @@ def stream(ws):
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
except:
generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error')
generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
generated_tokens = tokenize(generated_text)
traceback.print_exc()
ws.send(json.dumps({

View File

@@ -3,6 +3,7 @@ import time
from flask import jsonify, request
from . import bp
from ..auth import requires_auth
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
@@ -40,3 +41,9 @@ def get_model():
cache.set(cache_key, response, timeout=60)
return response
@bp.route('/backend', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
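A possible client call against the new admin-only route; host, port, and token are placeholders (the Basic form, with the admin token as the password, should work as well):

```python
# Hedged sketch: querying the new backend-info route with a Bearer admin token.
import requests

resp = requests.get(
    'http://localhost:5000/api/v1/backend',  # placeholder host/port
    headers={'Authorization': 'Bearer example-token'},
)
print(resp.status_code, resp.json())  # 200 {'backend': ..., 'mode': ...} when the token matches
```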

View File

@@ -67,8 +67,5 @@ class MainBackgroundThread(Thread):
def cache_stats():
while True:
# If opts.base_client_api is null that means no one has visited the site yet
# and the base_client_api hasn't been set. Do nothing until then.
if redis.get('base_client_api'):
x = generate_stats()
time.sleep(5)
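Since the removed and kept lines are interleaved above, this is my reading of the resulting loop (a sketch, not verbatim repo code): the `base_client_api` gate is gone and the stats cache is refreshed unconditionally.

```python
def cache_stats():
    # Refresh the memoized stats every 5 seconds so requests always hit a
    # warm cache instead of recomputing generate_stats() inline.
    while True:
        generate_stats()
        time.sleep(5)
```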

View File

@@ -13,7 +13,10 @@ auto_gptq
uvicorn~=0.23.2
fastapi~=0.103.1
torch~=2.0.1
urllib3
urllib3~=2.0.4
PyMySQL~=1.1.0
DBUtils~=3.0.3
simplejson
simplejson~=3.19.1
setuptools~=65.5.1
websockets~=11.0.3
basicauth~=1.0.0

View File

@@ -62,6 +62,7 @@ if config['mode'] not in ['oobabooga', 'vllm']:
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
# TODO: this is a MESS
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
@@ -83,6 +84,7 @@ opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
opts.admin_token = config['admin_token']
if config['http_host']:
redis.set('http_host', config['http_host'])
@@ -113,6 +115,10 @@ elif opts.mode == 'vllm':
else:
raise Exception
app = Flask(__name__)
cache.init_app(app)
cache.clear() # clear redis cache
# Start background processes
start_workers(opts.concurrent_gens)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
@@ -121,9 +127,6 @@ process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
app = Flask(__name__)
cache.init_app(app)
cache.clear() # clear redis cache
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
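The moved lines above are easier to follow as the resulting startup sequence; this is my reading, assembled from the hunks (a sketch, not verbatim):

```python
# The Flask app and cache are now created before the background workers
# start, so no worker touches an uninitialized cache.
app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache

# Start background processes
start_workers(opts.concurrent_gens)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()

init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
```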
@@ -195,6 +198,8 @@ def home():
)
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):