Commit 8cbf643fd3 by Cyberes, 2023-08-21 21:28:52 -06:00 (parent db0dfad83d)
28 changed files with 510 additions and 2 deletions

.gitignore (vendored, 4 changes)

@ -1,3 +1,6 @@
proxy-server.db
.idea
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
@ -159,4 +162,3 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

README.md

@ -1,3 +1,3 @@
# local-llm-server
An HTTP API to serve local LLM Models.
_An HTTP API to serve local LLM models._

config/config.yml (new file, 10 lines)

@ -0,0 +1,10 @@
# TODO: add this file to gitignore and add a .sample.yml
log_prompts: true
mode: oobabooga
auth_required: false
backend_url: https://proxy.chub-archive.evulid.cc
database_path: ./proxy-server.db

llm_server/__init__.py (new file, empty)

llm_server/config.py (new file, 39 lines)

@ -0,0 +1,39 @@
import yaml
class ConfigLoader:
    """
    Loads a YAML config file, fills in defaults, and checks that required keys are present.

    default_vars = {"var1": "default1", "var2": "default2"}
    required_vars = ["var1", "var3"]
    config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
    config = config_loader.load_config()
    """
def __init__(self, config_file, default_vars=None, required_vars=None):
self.config_file = config_file
self.default_vars = default_vars if default_vars else {}
self.required_vars = required_vars if required_vars else []
self.config = {}
    def load_config(self) -> tuple[bool, dict | None, str | None]:
with open(self.config_file, 'r') as stream:
try:
self.config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                return False, None, str(exc)
if self.config is None:
# Handle empty file
self.config = {}
# Set default variables if they are not present in the config file
for var, default_value in self.default_vars.items():
if var not in self.config:
self.config[var] = default_value
# Check if required variables are present in the config file
for var in self.required_vars:
if var not in self.config:
return False, None, f'Required variable "{var}" is missing from the config file'
return True, self.config, None

llm_server/database.py (new file, 75 lines)

@ -0,0 +1,75 @@
import json
import sqlite3
import time
from pathlib import Path
import tiktoken
from llm_server import opts
tokenizer = tiktoken.get_encoding("cl100k_base")
def init_db(db_path):
if not Path(db_path).exists():
conn = sqlite3.connect(db_path)
c = conn.cursor()
c.execute('''
CREATE TABLE prompts (
ip TEXT,
token TEXT DEFAULT NULL,
prompt TEXT,
prompt_tokens INTEGER,
response TEXT,
response_tokens INTEGER,
parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
timestamp INTEGER
)
''')
c.execute('''
CREATE TABLE token_auth
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER)
''')
conn.commit()
conn.close()
def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
prompt_tokens = len(tokenizer.encode(prompt))
response_tokens = len(tokenizer.encode(response))
    if not opts.log_prompts:
        # Keep token counts and metadata but drop the actual text when prompt logging is off
        prompt = response = None
timestamp = int(time.time())
conn = sqlite3.connect(db_path)
c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit()
conn.close()
def is_valid_api_key(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    conn.close()
    if row is not None:
        token, uses, max_uses, expire = row
        # NULL max_uses or expire means unlimited; guard against comparing against None
        if (max_uses is None or uses is None or uses < max_uses) and (expire is None or expire > time.time()):
            return True
    return False
def increment_uses(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    if row is not None:
        cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
        conn.commit()
        conn.close()
        return True
    conn.close()
    return False
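Nothing in this commit writes rows to token_auth, so tokens have to be seeded by hand. A minimal sketch against the schema above (the token value, type, and limits are hypothetical):

import sqlite3
import time

conn = sqlite3.connect('./proxy-server.db')
conn.execute(
    "INSERT INTO token_auth (token, type, uses, max_uses, expire) VALUES (?, ?, ?, ?, ?)",
    # hypothetical token: 1000 uses allowed, 30-day expiry; NULLs would mean unlimited
    ('my-secret-token', 'api', 0, 1000, int(time.time()) + 30 * 86400),
)
conn.commit()
conn.close()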

llm_server/helpers.py (new file, 5 lines)

@ -0,0 +1,5 @@
from pathlib import Path
def resolve_path(*p: str):
return Path(*p).expanduser().resolve().absolute()

llm_server/integer.py (new file, 12 lines)

@ -0,0 +1,12 @@
import threading
class ThreadSafeInteger:
def __init__(self, value=0):
self.value = value
self._value_lock = threading.Lock()
def increment(self):
with self._value_lock:
self.value += 1
return self.value
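
A quick sketch of the intended use (not part of the commit): the lock prevents lost updates when many request threads increment at once.

import threading
from llm_server.integer import ThreadSafeInteger

counter = ThreadSafeInteger()
threads = [threading.Thread(target=counter.increment) for _ in range(100)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert counter.value == 100  # no increments lost under the lock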

llm_server/llm/hf_textgen/generate.py (new file, 36 lines)

@ -0,0 +1,36 @@
import requests
from flask import current_app
from llm_server import opts
def prepare_json(json_data: dict):
token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
seed = json_data.get('seed', None)
if seed == -1:
seed = None
return {
'inputs': json_data.get('prompt', ''),
'parameters': {
            'max_new_tokens': opts.token_limit - token_count,  # tokens left for generation after the prompt
'repetition_penalty': json_data.get('repetition_penalty', None),
'seed': seed,
'stop': json_data.get('stopping_strings', []),
'temperature': json_data.get('temperature', None),
'top_k': json_data.get('top_k', None),
'top_p': json_data.get('top_p', None),
            'truncate': opts.token_limit,  # TGI's truncate is an input-token count, not a bool
'typical_p': json_data.get('typical_p', None),
'watermark': False
}
}
def generate(json_data: dict):
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
if r.status_code != 200:
return False, r, f'Backend returned {r.status_code}'
return True, r, None
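
For illustration (all values hypothetical), prepare_json maps an oobabooga-style body onto the TGI payload shape:

# incoming client body:
#   {'prompt': 'Hello', 'temperature': 0.7, 'seed': -1, 'stopping_strings': ['\n']}
# outgoing TGI payload, roughly:
#   {'inputs': 'Hello', 'parameters': {'max_new_tokens': opts.token_limit - <prompt tokens>,
#                                      'temperature': 0.7, 'seed': None, 'stop': ['\n'], ...}}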


@ -0,0 +1 @@
# Start the backend server

llm_server/llm/oobabooga/generate.py (new file, 13 lines)

@ -0,0 +1,13 @@
import requests
from llm_server import opts
def generate(json_data: dict):
    # oobabooga exposes its own /api/v1/generate, so the client body is proxied through unchanged
try:
r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
if r.status_code != 200:
return False, r, f'Backend returned {r.status_code}'
return True, r, None

llm_server/llm/oobabooga/info.py (new file, 15 lines)

@ -0,0 +1,15 @@
import requests
from llm_server import opts
def get_running_model():
try:
backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
except Exception as e:
return False
try:
r_json = backend_response.json()
return r_json['result']
except Exception as e:
return False


llm_server/opts.py (new file, 8 lines)

@ -0,0 +1,8 @@
# Global runtime options; server.py overwrites these from the config at startup
running_model = 'none'
concurrent_generates = 3
mode = 'oobabooga'
backend_url = None
token_limit = 5555
database_path = './proxy-server.db'
auth_required = False
log_prompts = False

llm_server/routes/cache.py (new file, 3 lines)

@ -0,0 +1,3 @@
from flask_caching import Cache
cache = Cache(config={'CACHE_TYPE': 'simple'})  # in-process cache; fine for a single worker

llm_server/routes/helpers/http.py (new file, 50 lines)

@ -0,0 +1,50 @@
import json
from functools import wraps
from typing import Union

from flask import jsonify, make_response, request
from requests import Response

from llm_server import opts
from llm_server.database import is_valid_api_key
def require_api_key():
    # Returns None when the request may proceed, or an error response to short-circuit it
    if not opts.auth_required:
return
elif 'X-Api-Key' in request.headers:
if is_valid_api_key(request.headers['X-Api-Key']):
return
else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
else:
return jsonify({'code': 401, 'message': 'API key required'}), 401
# def cache_control(seconds):
# def decorator(f):
# @wraps(f)
# def decorated_function(*args, **kwargs):
# resp = make_response(f(*args, **kwargs))
# resp.headers['Cache-Control'] = f'public, max-age={seconds}'
# return resp
#
# return decorated_function
#
# return decorator
def validate_json(data: Union[str, Response]):
    # Accepts either a raw request body or the requests.Response returned by a backend call
    if isinstance(data, Response):
try:
data = data.json()
return True, data
except Exception as e:
return False, None
try:
j = json.loads(data)
return True, j
except Exception as e:
return False, None
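
A sketch of the string-input call shape (values hypothetical; the same helper also takes the requests.Response from a backend call):

ok, parsed = validate_json('{"prompt": "hi"}')  # raw request body
assert ok and parsed['prompt'] == 'hi'

ok, parsed = validate_json('not json')
assert ok is False and parsed is None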

llm_server/routes/stats.py (new file, 10 lines)

@ -0,0 +1,10 @@
from datetime import datetime
from threading import Semaphore
from llm_server.integer import ThreadSafeInteger
from llm_server.opts import concurrent_generates
concurrent_semaphore = Semaphore(concurrent_generates)
proompts = ThreadSafeInteger(0)
start_time = datetime.now()

llm_server/routes/v1/__init__.py (new file, 18 lines)

@ -0,0 +1,18 @@
from flask import Blueprint, request
from ..helpers.http import require_api_key
bp = Blueprint('v1', __name__)
# openai_bp = Blueprint('/v1', __name__)
@bp.before_request
def before_request():
if request.endpoint != 'v1.get_stats':
response = require_api_key()
if response is not None:
return response
from . import generate, info, proxy  # imported last so the route modules can import bp without a circular import

llm_server/routes/v1/generate.py (new file, 64 lines)

@ -0,0 +1,64 @@
from flask import jsonify, request
from . import bp
from llm_server.routes.stats import concurrent_semaphore, proompts
from ..helpers.http import validate_json
from ... import opts
from ...database import log_prompt
if opts.mode == 'oobabooga':
from ...llm.oobabooga.generate import generate
generator = generate
elif opts.mode == 'hf-textgen':
from ...llm.hf_textgen.generate import generate
generator = generate
@bp.route('/generate', methods=['POST'])
def generate():  # shadows the backend import above; generator was already bound, so dispatch still works
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json:
return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
with concurrent_semaphore:
success, response, error_msg = generator(request_json_body)
if not success:
return jsonify({
'code': 500,
'error': 'failed to reach backend'
}), 500
response_valid_json, response_json_body = validate_json(response)
if response_valid_json:
proompts.increment()
# request.headers = {"host": "proxy.chub-archive.evulid.cc", "x-forwarded-proto": "https", "user-agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)", "cf-visitor": {"scheme": "https"}, "cf-ipcountry": "CH", "accept": "*/*", "accept-encoding": "gzip",
# "x-forwarded-for": "193.32.127.228", "cf-ray": "7fa72c6a6d5cbba7-FRA", "cf-connecting-ip": "193.32.127.228", "cdn-loop": "cloudflare", "content-type": "application/json", "content-length": "9039"}
if request.headers.get('cf-connecting-ip'):
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for')
else:
client_ip = request.remote_addr
            parameters = request_json_body.copy()
            parameters.pop('prompt', None)  # the prompt gets its own column; don't duplicate it in parameters
            token = request.headers.get('X-Api-Key')
            log_prompt(opts.database_path, client_ip, token, request_json_body.get('prompt', ''), response_json_body['results'][0]['text'], parameters, dict(request.headers))  # assumes the oobabooga response shape
return jsonify({
**response_json_body
}), 200
else:
return jsonify({
'code': 500,
'error': 'failed to reach backend'
}), 500
# @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai():
# print(request.data)
# return '', 200
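
A hypothetical client call against a local instance (host and port assumed; the X-Api-Key header only matters when auth_required is enabled):

import requests

resp = requests.post(
    'http://127.0.0.1:5000/api/v1/generate',
    json={'prompt': 'Hello, world.', 'max_new_tokens': 64, 'temperature': 0.7},
    headers={'X-Api-Key': 'my-secret-token'},  # hypothetical token seeded in token_auth
)
print(resp.json()['results'][0]['text'])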

llm_server/routes/v1/info.py (new file, 57 lines)

@ -0,0 +1,57 @@
import time
from flask import jsonify
from . import bp
from ...llm.oobabooga.info import get_running_model
from ..cache import cache
# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
# # requests.get()
# return 'yes'
@bp.route('/model', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def get_model():
model = get_running_model()
if not model:
return jsonify({
'code': 500,
'error': 'failed to reach backend'
}), 500
return jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
# @openai_bp.route('/models', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_openai_models():
# model = get_running_model()
# return {
# "object": "list",
# "data": [{
# "id": model,
# "object": "model",
# "created": stats.start_time,
# "owned_by": "openai",
# "permission": [{
# "id": f"modelperm-{model}",
# "object": "model_permission",
# "created": stats.start_time,
# "organization": "*",
# "group": None,
# "is_blocking": False
# }],
# "root": model,
# "parent": None
# }]
# }

llm_server/routes/v1/proxy.py (new file, 21 lines)

@ -0,0 +1,21 @@
import time
from datetime import datetime
from flask import jsonify
from llm_server import opts
from . import bp
from .. import stats
from llm_server.routes.v1.generate import concurrent_semaphore
from ..cache import cache
@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def get_stats():
return jsonify({
        'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,  # peeks at the semaphore's private counter; a rough gauge only
'total_proompts': stats.proompts.value,
'uptime': int((datetime.now() - stats.start_time).total_seconds()),
'timestamp': int(time.time())
}), 200
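
For reference, a hypothetical call (local host and default Flask port assumed; the response values below are made up):

import requests

print(requests.get('http://127.0.0.1:5000/api/v1/stats').json())
# e.g. {'proompters_now': 0, 'total_proompts': 42, 'uptime': 3600, 'timestamp': 1692672000}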

requirements.txt (new file, 6 lines)

@ -0,0 +1,6 @@
flask
flask_cors
pyyaml
flask_caching
requests
tiktoken

server.py (new file, 63 lines)

@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path
import tiktoken
from flask import Flask, current_app, jsonify
from llm_server import opts
from llm_server.config import ConfigLoader
from llm_server.database import init_db
from llm_server.helpers import resolve_path
from llm_server.llm.oobabooga.info import get_running_model
from llm_server.routes.cache import cache
from llm_server.routes.v1 import bp
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = Path(os.path.dirname(os.path.realpath(__file__)), 'config', 'config.yml')
default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False}
required_vars = []
config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()
if not success:
print('Failed to load config:', msg)
sys.exit(1)
opts.backend_url = config['backend_url'].rstrip('/')  # normalize away any trailing slash
opts.database_path = resolve_path(config['database_path'])
init_db(opts.database_path)
if config['mode'] not in ['oobabooga', 'hf-textgen']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.running_model = get_running_model()
app = Flask(__name__)
cache.init_app(app)
with app.app_context():
    # hf_textgen's prepare_json reads current_app.tokenizer, so bind it while the app context is live
    current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')
print(app.url_map)
@app.route('/')
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
return jsonify({
'error': 404,
'msg': 'not found'
}), 404
if __name__ == "__main__":
app.run(host='0.0.0.0')
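
With config/config.yml in place (or CONFIG_PATH pointing elsewhere), running python server.py serves the API on all interfaces through Flask's development server, which defaults to port 5000. A hypothetical smoke test once it is up and the backend is reachable:

import requests

print(requests.get('http://127.0.0.1:5000/api/v1/model').json())
# e.g. {'result': 'llama-13b', 'timestamp': 1692672000}  # model name depends on the backend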