prototype hf-textgen and adjust logging

Cyberes 2023-08-22 19:58:31 -06:00
parent a59dcea2da
commit 0d32db2dbd
10 changed files with 75 additions and 40 deletions


@@ -5,6 +5,7 @@ log_prompts: true
 mode: oobabooga
 auth_required: false
 concurrent_gens: 3
+token_limit: 5555
 backend_url: http://172.0.0.2:9104


@@ -22,6 +22,7 @@ def init_db(db_path):
         prompt_tokens INTEGER,
         response TEXT,
         response_tokens INTEGER,
+        response_status INTEGER,
         parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
         headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
         timestamp INTEGER
@@ -31,11 +32,15 @@ def init_db(db_path):
         CREATE TABLE token_auth
         (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
     ''')
+    # c.execute('''
+    #     CREATE TABLE leeches
+    #     (url TEXT, online TEXT)
+    # ''')
     conn.commit()
     conn.close()


-def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
+def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
     prompt_tokens = len(tokenizer.encode(prompt))
     response_tokens = len(tokenizer.encode(response))
@@ -45,8 +50,8 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
     timestamp = int(time.time())
     conn = sqlite3.connect(db_path)
     c = conn.cursor()
-    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
-              (ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
+    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+              (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     conn.commit()
     conn.close()
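
For reference, a minimal sketch of a call site for the widened log_prompt signature. The names client_ip, response_text and backend_response are placeholders for whatever the route has on hand (they are not taken from this commit), and the sketch assumes it runs inside a Flask request context:

    from llm_server import opts
    from llm_server.database import log_prompt

    # backend_response is assumed to be the requests.Response from the backend call,
    # so its HTTP status lands in the new response_status column.
    log_prompt(
        opts.database_path,
        client_ip,
        request.headers.get('X-Api-Key'),
        request_json_body['prompt'],
        response_text,
        parameters,
        dict(request.headers),
        backend_response.status_code,
    )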


@@ -1,36 +1,44 @@
+import json
 import requests
-from flask import current_app
 from llm_server import opts
+from llm_server.database import tokenizer


 def prepare_json(json_data: dict):
-    token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
+    token_count = len(tokenizer.encode(json_data.get('prompt', '')))
     seed = json_data.get('seed', None)
     if seed == -1:
         seed = None
+    typical_p = json_data.get('typical_p', None)
+    if typical_p >= 1:
+        typical_p = 0.999
     return {
         'inputs': json_data.get('prompt', ''),
         'parameters': {
-            'max_new_tokens': token_count - opts.token_limit,
+            'max_new_tokens': opts.token_limit - token_count,
             'repetition_penalty': json_data.get('repetition_penalty', None),
             'seed': seed,
             'stop': json_data.get('stopping_strings', []),
             'temperature': json_data.get('temperature', None),
             'top_k': json_data.get('top_k', None),
             'top_p': json_data.get('top_p', None),
-            'truncate': True,
-            'typical_p': json_data.get('typical_p', None),
+            # 'truncate': opts.token_limit,
+            'typical_p': typical_p,
             'watermark': False
         }
     }


 def generate(json_data: dict):
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    if r.status_code != 200:
-        return False, r, f'Backend returned {r.status_code}'
-    return True, r, None
+    print(json.dumps(prepare_json(json_data)))
+    # try:
+    r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
+    print(r.text)
+    # except Exception as e:
+    #     return False, None, f'{e.__class__.__name__}: {e}'
+    # if r.status_code != 200:
+    #     return False, r, f'Backend returned {r.status_code}'
+    # return True, r, None
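
To make the new budget arithmetic concrete: max_new_tokens is now whatever remains of the configured token_limit after the prompt, i.e. opts.token_limit - token_count. A rough usage sketch, assuming token_limit: 5555 from the config above (the prompt string and sampling values are made up for illustration):

    from llm_server import opts
    from llm_server.llm.hf_textgen.generate import prepare_json

    opts.token_limit = 5555

    payload = prepare_json({'prompt': 'Once upon a time', 'seed': -1, 'typical_p': 1.0, 'temperature': 0.7})
    # payload['inputs'] == 'Once upon a time'
    # payload['parameters']['max_new_tokens'] == 5555 minus the prompt's token count
    # payload['parameters']['seed'] is None        (-1 is treated as "no seed")
    # payload['parameters']['typical_p'] == 0.999  (values >= 1 are clamped)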

llm_server/llm/info.py (new file)

@@ -0,0 +1,28 @@
+import requests
+
+from llm_server import opts
+
+
+def get_running_model():
+    if opts.mode == 'oobabooga':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['result']
+        except Exception as e:
+            return False
+    elif opts.mode == 'hf-textgen':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/info')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['model_id'].replace('/', '_')
+        except Exception as e:
+            return False
+    else:
+        raise Exception
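
A short usage sketch of the new helper: it returns the running model's name (with '/' replaced by '_' in hf-textgen mode) or False when the backend is unreachable or returns unexpected JSON:

    from llm_server.llm.info import get_running_model

    model = get_running_model()
    if model is False:
        # backend offline or returned malformed JSON
        status = 'offline'
    else:
        status = f'running {model}'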


@@ -1,15 +1,3 @@
-import requests
-from llm_server import opts
-def get_running_model():
-    try:
-        backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
-    except Exception as e:
-        return False
-    try:
-        r_json = backend_response.json()
-        return r_json['result']
-    except Exception as e:
-        return False


@@ -6,14 +6,16 @@ from ..helpers.http import cache_control, validate_json
 from ... import opts
 from ...database import log_prompt

-if opts.mode == 'oobabooga':
-    from ...llm.oobabooga.generate import generate
-    generator = generate
-elif opts.mode == 'hf-textgen':
-    from ...llm.hf_textgen.generate import generate
-    generator = generate
+def generator(request_json_body):
+    if opts.mode == 'oobabooga':
+        from ...llm.oobabooga.generate import generate
+        return generate(request_json_body)
+    elif opts.mode == 'hf-textgen':
+        from ...llm.hf_textgen.generate import generate
+        return generate(request_json_body)
+    else:
+        raise Exception


 @bp.route('/generate', methods=['POST'])

@@ -49,7 +51,7 @@ def generate():
     token = request.headers.get('X-Api-Key')
-    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers))
+    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
     return jsonify({
         **response_json_body
     }), 200
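
The mode check now happens per request inside generator() rather than once at module import time. A minimal sketch of the dispatch as it would look inside this route module, with a made-up request body:

    # Hypothetical request body; generator() forwards it to the generate()
    # implementation selected by opts.mode (for hf-textgen, the prototype above
    # currently just POSTs to the backend and prints the response).
    opts.mode = 'hf-textgen'
    generator({'prompt': 'Hello there', 'temperature': 0.7})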


@@ -4,7 +4,8 @@ from flask import jsonify
 from . import bp
 from ..helpers.http import cache_control
-from ...llm.oobabooga.info import get_running_model
+from ... import opts
+from ...llm.info import get_running_model
 from ..cache import cache


@@ -10,7 +10,7 @@ from .. import stats
 from ..cache import cache
 from ..helpers.http import cache_control
 from ..stats import proompters_1_min
-from ...llm.oobabooga.info import get_running_model
+from ...llm.info import get_running_model


 @bp.route('/stats', methods=['GET'])


@@ -9,7 +9,7 @@ from llm_server import opts
 from llm_server.config import ConfigLoader
 from llm_server.database import init_db
 from llm_server.helpers import resolve_path
-from llm_server.llm.oobabooga.info import get_running_model
+from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.v1 import bp

@@ -23,7 +23,7 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')

 default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': ''}
-required_vars = []
+required_vars = ['token_limit']
 config_loader = ConfigLoader(config_path, default_vars, required_vars)
 success, config, msg = config_loader.load_config()
 if not success:

@@ -46,9 +46,11 @@ opts.auth_required = config['auth_required']
 opts.log_prompts = config['log_prompts']
 opts.concurrent_gens = config['concurrent_gens']
 opts.frontend_api_client = config['frontend_api_client']
+opts.token_limit = config['token_limit']

 app = Flask(__name__)
 cache.init_app(app)
+cache.clear() # clear redis cache
 # with app.app_context():
 #     current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
 app.register_blueprint(bp, url_prefix='/api/v1/')

@@ -69,4 +71,4 @@ def fallback(first=None, rest=None):
 if __name__ == "__main__":
-    app.run(host='0.0.0.0')
+    app.run(host='0.0.0.0', debug=True)
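
Since token_limit is now listed in required_vars rather than given a default, a config.yml without it should make load_config() report a failure before the server binds. A minimal sketch of that startup path, assuming ConfigLoader keeps the (success, config, msg) return shape shown above (the error handling below is illustrative, not the code in this commit):

    config_loader = ConfigLoader(config_path, default_vars, required_vars)
    success, config, msg = config_loader.load_config()
    if not success:
        # e.g. token_limit missing from config.yml
        raise SystemExit(f'Failed to load config: {msg}')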