MVP
parent db0dfad83d
commit 8cbf643fd3
.gitignore
@@ -1,3 +1,6 @@
proxy-server.db
.idea

# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/

@@ -159,4 +162,3 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
@@ -1,3 +1,3 @@
# local-llm-server

An HTTP API to serve local LLM Models.
_An HTTP API to serve local LLM Models._
config/config.yml
@@ -0,0 +1,10 @@
# TODO: add this file to gitignore and add a .sample.yml

log_prompts: true

mode: oobabooga
auth_required: false

backend_url: https://proxy.chub-archive.evulid.cc

database_path: ./proxy-server.db
llm_server/config.py
@@ -0,0 +1,39 @@
import yaml


class ConfigLoader:
    """
    default_vars = {"var1": "default1", "var2": "default2"}
    required_vars = ["var1", "var3"]
    config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
    config = config_loader.load_config()
    """

    def __init__(self, config_file, default_vars=None, required_vars=None):
        self.config_file = config_file
        self.default_vars = default_vars if default_vars else {}
        self.required_vars = required_vars if required_vars else []
        self.config = {}

    def load_config(self) -> tuple[bool, dict | None, str | Exception | None]:
        with open(self.config_file, 'r') as stream:
            try:
                self.config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                return False, None, exc

        if self.config is None:
            # Handle empty file
            self.config = {}

        # Set default variables if they are not present in the config file
        for var, default_value in self.default_vars.items():
            if var not in self.config:
                self.config[var] = default_value

        # Check if required variables are present in the config file
        for var in self.required_vars:
            if var not in self.config:
                return False, None, f'Required variable "{var}" is missing from the config file'

        return True, self.config, None
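For orientation, a minimal sketch of driving this loader; it mirrors the call in server.py further down, and the config path and variable names here are illustrative:

from llm_server.config import ConfigLoader

config_loader = ConfigLoader('config/config.yml',
                             default_vars={'log_prompts': False},
                             required_vars=['backend_url'])
success, config, msg = config_loader.load_config()
if not success:
    raise SystemExit(f'Failed to load config: {msg}')
print(config['backend_url'])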
llm_server/database.py
@@ -0,0 +1,75 @@
import json
import sqlite3
import time
from pathlib import Path

import tiktoken

from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def init_db(db_path):
    if not Path(db_path).exists():
        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        c.execute('''
            CREATE TABLE prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                prompt TEXT,
                prompt_tokens INTEGER,
                response TEXT,
                response_tokens INTEGER,
                parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
                headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
                timestamp INTEGER
            )
        ''')
        c.execute('''
            CREATE TABLE token_auth
            (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER)
        ''')
        conn.commit()
        conn.close()


def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))

    if not opts.log_prompts:
        # Still record token counts, but drop the raw text.
        prompt = response = None

    timestamp = int(time.time())
    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
              (ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
    conn.commit()
    conn.close()


def is_valid_api_key(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token, uses, max_uses, expire FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    conn.close()
    if row is not None:
        token, uses, max_uses, expire = row
        # NULL max_uses or expire means unlimited uses / no expiry.
        if (max_uses is None or uses is None or uses < max_uses) and (expire is None or expire > time.time()):
            return True
    return False


def increment_uses(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    if row is None:
        conn.close()
        return False
    cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
    conn.commit()
    conn.close()
    return True
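A hedged smoke test of these helpers (all values hypothetical; with log_prompts off, the raw text is stored as NULL while the token counts are kept):

from llm_server.database import init_db, is_valid_api_key, log_prompt

init_db('./proxy-server.db')
log_prompt('./proxy-server.db', '127.0.0.1', None, 'Hello', 'Hi there!',
           {'temperature': 0.7}, {'User-Agent': 'curl/8.1.0'})
print(is_valid_api_key('not-a-real-token'))  # False: token_auth starts empty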
llm_server/helpers.py
@@ -0,0 +1,5 @@
from pathlib import Path


def resolve_path(*p: str):
    return Path(*p).expanduser().resolve().absolute()
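A one-liner showing the intent (the printed path is machine-dependent):

from llm_server.helpers import resolve_path

print(resolve_path('~', 'proxy-server.db'))  # e.g. /home/user/proxy-server.db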
llm_server/integer.py
@@ -0,0 +1,12 @@
import threading


class ThreadSafeInteger:
    def __init__(self, value=0):
        self.value = value
        self._value_lock = threading.Lock()

    def increment(self):
        with self._value_lock:
            self.value += 1
            return self.value
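To illustrate why the lock matters, a minimal sketch: concurrent increments still sum exactly.

import threading

from llm_server.integer import ThreadSafeInteger

counter = ThreadSafeInteger()
threads = [threading.Thread(target=lambda: [counter.increment() for _ in range(1000)])
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert counter.value == 4000  # the lock prevents lost updates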
llm_server/llm/hf_textgen/generate.py
@@ -0,0 +1,36 @@
import requests
from flask import current_app

from llm_server import opts


def prepare_json(json_data: dict):
    # Translate an oobabooga-style request body into the HF text-generation-inference format.
    token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
    seed = json_data.get('seed', None)
    if seed == -1:
        seed = None
    return {
        'inputs': json_data.get('prompt', ''),
        'parameters': {
            'max_new_tokens': opts.token_limit - token_count,  # remaining budget after the prompt
            'repetition_penalty': json_data.get('repetition_penalty', None),
            'seed': seed,
            'stop': json_data.get('stopping_strings', []),
            'temperature': json_data.get('temperature', None),
            'top_k': json_data.get('top_k', None),
            'top_p': json_data.get('top_p', None),
            'truncate': True,
            'typical_p': json_data.get('typical_p', None),
            'watermark': False
        }
    }


def generate(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
@@ -0,0 +1 @@
# Start the backend server
llm_server/llm/oobabooga/generate.py
@@ -0,0 +1,13 @@
import requests

from llm_server import opts


def generate(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data)
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
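Both backends share the same (success, response, error) return contract; a hedged sketch of consuming it, assuming opts.backend_url is already configured and with an illustrative prompt:

from llm_server.llm.oobabooga.generate import generate

success, response, error = generate({'prompt': 'Once upon a time'})
if success:
    print(response.json())
else:
    print('backend error:', error)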
llm_server/llm/oobabooga/info.py
@@ -0,0 +1,15 @@
import requests

from llm_server import opts


def get_running_model():
    try:
        backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
    except Exception:
        return False
    try:
        r_json = backend_response.json()
        return r_json['result']
    except Exception:
        return False
llm_server/opts.py
@@ -0,0 +1,8 @@
# Global, mutable runtime settings; overwritten from the config file in server.py.
running_model = 'none'
concurrent_generates = 3
mode = 'oobabooga'
backend_url = None
token_limit = 5555
database_path = './proxy-server.db'
auth_required = False
log_prompts = False
llm_server/routes/cache.py
@@ -0,0 +1,3 @@
from flask_caching import Cache

cache = Cache(config={'CACHE_TYPE': 'simple'})
llm_server/routes/helpers/http.py
@@ -0,0 +1,50 @@
import json
from functools import wraps
from typing import Union

from flask import jsonify, make_response, request
from requests import Response

from llm_server import opts
from llm_server.database import is_valid_api_key


def require_api_key():
    if not opts.auth_required:
        return
    elif 'X-Api-Key' in request.headers:
        if is_valid_api_key(request.headers['X-Api-Key']):
            return
        else:
            return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
    else:
        return jsonify({'code': 401, 'message': 'API key required'}), 401


# def cache_control(seconds):
#     def decorator(f):
#         @wraps(f)
#         def decorated_function(*args, **kwargs):
#             resp = make_response(f(*args, **kwargs))
#             resp.headers['Cache-Control'] = f'public, max-age={seconds}'
#             return resp
#
#         return decorated_function
#
#     return decorator


def validate_json(data: Union[str, Response]):
    if isinstance(data, Response):
        try:
            data = data.json()
            return True, data
        except Exception:
            return False, None
    try:
        j = json.loads(data)
        return True, j
    except Exception:
        return False, None
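validate_json returns an (ok, parsed) pair for both raw strings and requests responses; a quick illustration on strings:

from llm_server.routes.helpers.http import validate_json

ok, parsed = validate_json('{"prompt": "hi"}')
assert ok and parsed['prompt'] == 'hi'
ok, parsed = validate_json('not json')
assert not ok and parsed is None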
llm_server/routes/stats.py
@@ -0,0 +1,10 @@
from datetime import datetime
from threading import Semaphore

from llm_server.integer import ThreadSafeInteger
from llm_server.opts import concurrent_generates

concurrent_semaphore = Semaphore(concurrent_generates)
proompts = ThreadSafeInteger(0)
start_time = datetime.now()
llm_server/routes/v1/__init__.py
@@ -0,0 +1,18 @@
from flask import Blueprint, request

from ..helpers.http import require_api_key

bp = Blueprint('v1', __name__)


# openai_bp = Blueprint('/v1', __name__)

@bp.before_request
def before_request():
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:
            return response


# Imported at the bottom so the route modules can import `bp` without a circular import.
from . import generate, info, proxy
llm_server/routes/v1/generate.py
@@ -0,0 +1,64 @@
from flask import jsonify, request

from . import bp
from llm_server.routes.stats import concurrent_semaphore, proompts
from ..helpers.http import validate_json
from ... import opts
from ...database import log_prompt

if opts.mode == 'oobabooga':
    from ...llm.oobabooga.generate import generate

    generator = generate
elif opts.mode == 'hf-textgen':
    from ...llm.hf_textgen.generate import generate

    generator = generate


@bp.route('/generate', methods=['POST'])
def generate():
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

    with concurrent_semaphore:
        success, response, error_msg = generator(request_json_body)
        if not success:
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend'
            }), 500
        response_valid_json, response_json_body = validate_json(response)
        if response_valid_json:
            proompts.increment()

            # request.headers = {"host": "proxy.chub-archive.evulid.cc", "x-forwarded-proto": "https", "user-agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)", "cf-visitor": {"scheme": "https"}, "cf-ipcountry": "CH", "accept": "*/*", "accept-encoding": "gzip",
            # "x-forwarded-for": "193.32.127.228", "cf-ray": "7fa72c6a6d5cbba7-FRA", "cf-connecting-ip": "193.32.127.228", "cdn-loop": "cloudflare", "content-type": "application/json", "content-length": "9039"}

            # Prefer Cloudflare's header, then any upstream proxy header, then the socket address.
            if request.headers.get('cf-connecting-ip'):
                client_ip = request.headers.get('cf-connecting-ip')
            elif request.headers.get('x-forwarded-for'):
                client_ip = request.headers.get('x-forwarded-for')
            else:
                client_ip = request.remote_addr

            parameters = request_json_body.copy()
            del parameters['prompt']

            token = request.headers.get('X-Api-Key')

            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers))
            return jsonify({
                **response_json_body
            }), 200
        else:
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend'
            }), 500


# @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai():
#     print(request.data)
#     return '', 200
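A hedged client sketch against this route (local dev server on Flask's default port assumed; the endpoint path comes from the url_prefix in server.py below, and the response shape is the oobabooga pass-through that log_prompt also expects):

import requests

r = requests.post('http://127.0.0.1:5000/api/v1/generate',
                  json={'prompt': 'Once upon a time', 'temperature': 0.7})
print(r.json()['results'][0]['text'])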
llm_server/routes/v1/info.py
@@ -0,0 +1,57 @@
import time

from flask import jsonify

from . import bp
from ...llm.oobabooga.info import get_running_model
from ..cache import cache


# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})


# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
#     # requests.get()
#     return 'yes'


@bp.route('/model', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def get_model():
    model = get_running_model()
    if not model:
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend'
        }), 500

    return jsonify({
        'result': model,
        'timestamp': int(time.time())
    }), 200


# @openai_bp.route('/models', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_openai_models():
#     model = get_running_model()
#     return {
#         "object": "list",
#         "data": [{
#             "id": model,
#             "object": "model",
#             "created": stats.start_time,
#             "owned_by": "openai",
#             "permission": [{
#                 "id": f"modelperm-{model}",
#                 "object": "model_permission",
#                 "created": stats.start_time,
#                 "organization": "*",
#                 "group": None,
#                 "is_blocking": False
#             }],
#             "root": model,
#             "parent": None
#         }]
#     }
llm_server/routes/v1/proxy.py
@@ -0,0 +1,21 @@
import time
from datetime import datetime

from flask import jsonify

from llm_server import opts
from . import bp
from .. import stats
from llm_server.routes.v1.generate import concurrent_semaphore
from ..cache import cache


@bp.route('/stats', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def get_stats():
    return jsonify({
        # Semaphore._value is private, but it gives a cheap gauge of in-flight generations.
        'proompters_now': opts.concurrent_generates - concurrent_semaphore._value,
        'total_proompts': stats.proompts.value,
        'uptime': int((datetime.now() - stats.start_time).total_seconds()),
        'timestamp': int(time.time())
    }), 200
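And a hedged peek at the stats route (local dev server assumed; values illustrative):

import requests

print(requests.get('http://127.0.0.1:5000/api/v1/stats').json())
# e.g. {'proompters_now': 0, 'timestamp': 1693000000, 'total_proompts': 42, 'uptime': 3600}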
requirements.txt
@@ -0,0 +1,6 @@
flask
flask_cors
pyyaml
flask_caching
requests
tiktoken
server.py
@@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path

import tiktoken
from flask import Flask, current_app, jsonify

from llm_server import opts
from llm_server.config import ConfigLoader
from llm_server.database import init_db
from llm_server.helpers import resolve_path
from llm_server.llm.oobabooga.info import get_running_model
from llm_server.routes.cache import cache
from llm_server.routes.v1 import bp

config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(os.path.dirname(os.path.realpath(__file__)), 'config', 'config.yml')

default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False}
required_vars = []
config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

opts.backend_url = config['backend_url'].strip('/')

opts.database_path = resolve_path(config['database_path'])
init_db(opts.database_path)

if config['mode'] not in ['oobabooga', 'hf-textgen']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']

opts.running_model = get_running_model()

app = Flask(__name__)
cache.init_app(app)
# with app.app_context():
#     current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')

print(app.url_map)


@app.route('/')
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    return jsonify({
        'error': 404,
        'msg': 'not found'
    }), 404


if __name__ == "__main__":
    app.run(host='0.0.0.0')