minor changes, add admin token auth system, add route to get backend info
This commit is contained in:
parent
2678102153
commit
8d6b2ce49c
17
README.md
17
README.md
|
@ -2,7 +2,7 @@
|
|||
|
||||
_An HTTP API to serve local LLM Models._
|
||||
|
||||
The purpose of this server is to abstract your LLM backend from your frontend API. This enables you to make changes to (or even switch) your backend without affecting your clients.
|
||||
The purpose of this server is to abstract your LLM backend from your frontend API. This enables you to switch your backend while providing a stable frontend clients.
|
||||
|
||||
### Install
|
||||
|
||||
|
@ -43,14 +43,13 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
|
|||
|
||||
### Use
|
||||
|
||||
**DO NOT** lose your database. It's used for calculating the estimated wait time based on average TPS and response tokens and if you lose those stats your numbers will be inaccurate until the database fills back up again. If you change GPUs, you
|
||||
should probably clear the `generation_time` time column in the `prompts` table.
|
||||
|
||||
|
||||
### To Do
|
||||
|
||||
- Implement streaming
|
||||
- Add `huggingface/text-generation-inference`
|
||||
- Convince Oobabooga to implement concurrent generation
|
||||
- Make sure stats work when starting from an empty database
|
||||
- Make sure we're correctly canceling requests when the client cancels
|
||||
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and the remove it before proxying??
|
||||
- [x] Implement streaming
|
||||
- [ ] Bring streaming endpoint up to the level of the blocking endpoint
|
||||
- [x] Add VLLM support
|
||||
- [ ] Make sure stats work when starting from an empty database
|
||||
- [ ] Make sure we're correctly canceling requests when the client cancels
|
||||
- [ ] Make sure the OpenAI endpoint works as expected
|
|
@ -23,6 +23,7 @@ config_default_vars = {
|
|||
'expose_openai_system_prompt': True,
|
||||
'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
|
||||
'http_host': None,
|
||||
'admin_token': None,
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -28,3 +28,4 @@ enable_streaming = True
|
|||
openai_api_key = None
|
||||
backend_request_timeout = 30
|
||||
backend_generate_request_timeout = 95
|
||||
admin_token = None
|
|
@ -0,0 +1,39 @@
|
|||
from functools import wraps
|
||||
|
||||
import basicauth
|
||||
from flask import Response, request
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def check_auth(token):
|
||||
password = None
|
||||
if token.startswith('Basic '):
|
||||
try:
|
||||
_, password = basicauth.decode(token)
|
||||
except:
|
||||
return False
|
||||
elif token.startswith('Bearer '):
|
||||
password = token.split('Bearer ', maxsplit=1)
|
||||
del password[0]
|
||||
password = ''.join(password)
|
||||
print(password, token)
|
||||
return password == opts.admin_token
|
||||
|
||||
|
||||
def authenticate():
|
||||
"""Sends a 401 response that enables basic auth"""
|
||||
return Response(
|
||||
'AUTHENTICATION REQUIRED', 401,
|
||||
{'WWW-Authenticate': 'Basic realm="Login Required"'})
|
||||
|
||||
|
||||
def requires_auth(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
auth = request.headers.get('Authorization')
|
||||
if not auth or not check_auth(auth):
|
||||
return authenticate()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated
|
|
@ -3,7 +3,7 @@ from llm_server.routes.cache import redis
|
|||
|
||||
|
||||
def format_sillytavern_err(msg: str, level: str = 'info'):
|
||||
http_host = redis.get('http_host')
|
||||
http_host = redis.get('http_host', str)
|
||||
return f"""```
|
||||
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
|
||||
-> {level.upper()} <-
|
||||
|
|
|
@ -35,7 +35,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
|
|||
|
||||
# TODO: have routes/__init__.py point to the latest API version generate_stats()
|
||||
|
||||
@cache.memoize(timeout=10)
|
||||
@cache.memoize(timeout=20)
|
||||
def generate_stats():
|
||||
model_name, error = get_running_model() # will return False when the fetch fails
|
||||
if isinstance(model_name, bool):
|
||||
|
|
|
@ -102,7 +102,7 @@ def stream(ws):
|
|||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
|
||||
except:
|
||||
generated_text = generated_text + '\n\n' + format_sillytavern_err('encountered error while streaming', 'error')
|
||||
generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
|
||||
generated_tokens = tokenize(generated_text)
|
||||
traceback.print_exc()
|
||||
ws.send(json.dumps({
|
||||
|
|
|
@ -3,6 +3,7 @@ import time
|
|||
from flask import jsonify, request
|
||||
|
||||
from . import bp
|
||||
from ..auth import requires_auth
|
||||
from ..cache import cache
|
||||
from ... import opts
|
||||
from ...llm.info import get_running_model
|
||||
|
@ -40,3 +41,9 @@ def get_model():
|
|||
cache.set(cache_key, response, timeout=60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@bp.route('/backend', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
|
||||
|
|
|
@ -67,8 +67,5 @@ class MainBackgroundThread(Thread):
|
|||
|
||||
def cache_stats():
|
||||
while True:
|
||||
# If opts.base_client_api is null that means no one has visited the site yet
|
||||
# and the base_client_api hasn't been set. Do nothing until then.
|
||||
if redis.get('base_client_api'):
|
||||
x = generate_stats()
|
||||
time.sleep(5)
|
||||
|
|
|
@ -13,7 +13,10 @@ auto_gptq
|
|||
uvicorn~=0.23.2
|
||||
fastapi~=0.103.1
|
||||
torch~=2.0.1
|
||||
urllib3
|
||||
urllib3~=2.0.4
|
||||
PyMySQL~=1.1.0
|
||||
DBUtils~=3.0.3
|
||||
simplejson
|
||||
simplejson~=3.19.1
|
||||
setuptools~=65.5.1
|
||||
websockets~=11.0.3
|
||||
basicauth~=1.0.0
|
11
server.py
11
server.py
|
@ -62,6 +62,7 @@ if config['mode'] not in ['oobabooga', 'vllm']:
|
|||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
# TODO: this is a MESS
|
||||
opts.mode = config['mode']
|
||||
opts.auth_required = config['auth_required']
|
||||
opts.log_prompts = config['log_prompts']
|
||||
|
@ -83,6 +84,7 @@ opts.openai_system_prompt = config['openai_system_prompt']
|
|||
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
|
||||
opts.enable_streaming = config['enable_streaming']
|
||||
opts.openai_api_key = config['openai_api_key']
|
||||
opts.admin_token = config['admin_token']
|
||||
|
||||
if config['http_host']:
|
||||
redis.set('http_host', config['http_host'])
|
||||
|
@ -113,6 +115,10 @@ elif opts.mode == 'vllm':
|
|||
else:
|
||||
raise Exception
|
||||
|
||||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
|
||||
# Start background processes
|
||||
start_workers(opts.concurrent_gens)
|
||||
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
|
||||
|
@ -121,9 +127,6 @@ process_avg_gen_time_background_thread.start()
|
|||
MainBackgroundThread().start()
|
||||
SemaphoreCheckerThread().start()
|
||||
|
||||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
|
@ -195,6 +198,8 @@ def home():
|
|||
)
|
||||
|
||||
|
||||
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
|
||||
|
||||
@app.route('/<first>')
|
||||
@app.route('/<first>/<path:rest>')
|
||||
def fallback(first=None, rest=None):
|
||||
|
|
Reference in New Issue