prototype hf-textgen and adjust logging

Cyberes 2023-08-22 19:58:31 -06:00
parent a59dcea2da
commit 0d32db2dbd
10 changed files with 75 additions and 40 deletions

View File

@@ -5,6 +5,7 @@ log_prompts: true
mode: oobabooga
auth_required: false
concurrent_gens: 3
+token_limit: 5555
backend_url: http://172.0.0.2:9104
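
The new token_limit key is made required by server.py further down and sets the total token budget per request. A minimal sketch of the budget arithmetic it feeds, mirroring prepare_json() in the hf-textgen backend below; the tiktoken encoding name is taken from the commented-out line in server.py and is only an illustration:

import tiktoken

token_limit = 5555  # value from config.yml above
tokenizer = tiktoken.get_encoding("cl100k_base")  # assumed encoding, for illustration

prompt = "Once upon a time"
prompt_tokens = len(tokenizer.encode(prompt))

# the backend is asked for at most the remaining budget
max_new_tokens = token_limit - prompt_tokens
print(prompt_tokens, max_new_tokens)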

View File

@@ -22,6 +22,7 @@ def init_db(db_path):
prompt_tokens INTEGER,
response TEXT,
response_tokens INTEGER,
+response_status INTEGER,
parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
timestamp INTEGER
@@ -31,11 +32,15 @@ def init_db(db_path):
CREATE TABLE token_auth
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
''')
+    # c.execute('''
+    # CREATE TABLE leeches
+    # (url TEXT, online TEXT)
+    # ''')
    conn.commit()
    conn.close()

-def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
+def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
@@ -45,8 +50,8 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
    timestamp = int(time.time())
    conn = sqlite3.connect(db_path)
    c = conn.cursor()
-    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
-              (ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
+    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+              (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    conn.commit()
    conn.close()
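
With the new response_status column, log_prompt() now takes the backend's HTTP status as its last argument and writes a ten-column row. A minimal sketch of a call, assuming the database has been initialised with init_db() and using illustrative values:

from llm_server.database import init_db, log_prompt

db_path = './proxy-server.db'
init_db(db_path)

log_prompt(
    db_path,
    ip='10.0.0.5',                      # illustrative client address
    token=None,                         # no X-Api-Key header sent
    prompt='Hello',
    response='Hi there.',
    parameters={'temperature': 0.7},
    headers={'User-Agent': 'curl/8.2'},
    backend_response_code=200,          # stored in the new response_status column
)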

View File

@@ -1,36 +1,44 @@
+import json
import requests
-from flask import current_app
from llm_server import opts
+from llm_server.database import tokenizer


def prepare_json(json_data: dict):
-    token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
+    token_count = len(tokenizer.encode(json_data.get('prompt', '')))
    seed = json_data.get('seed', None)
    if seed == -1:
        seed = None
+    typical_p = json_data.get('typical_p', None)
+    if typical_p >= 1:
+        typical_p = 0.999
    return {
        'inputs': json_data.get('prompt', ''),
        'parameters': {
-            'max_new_tokens': token_count - opts.token_limit,
+            'max_new_tokens': opts.token_limit - token_count,
            'repetition_penalty': json_data.get('repetition_penalty', None),
            'seed': seed,
            'stop': json_data.get('stopping_strings', []),
            'temperature': json_data.get('temperature', None),
            'top_k': json_data.get('top_k', None),
            'top_p': json_data.get('top_p', None),
-            'truncate': True,
-            'typical_p': json_data.get('typical_p', None),
+            # 'truncate': opts.token_limit,
+            'typical_p': typical_p,
            'watermark': False
        }
    }


def generate(json_data: dict):
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    if r.status_code != 200:
-        return False, r, f'Backend returned {r.status_code}'
-    return True, r, None
+    print(json.dumps(prepare_json(json_data)))
+    # try:
+    r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
+    print(r.text)
+    # except Exception as e:
+    # return False, None, f'{e.__class__.__name__}: {e}'
+    # if r.status_code != 200:
+    # return False, r, f'Backend returned {r.status_code}'
+    # return True, r, None
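
The prototype generate() above prints the outgoing payload and the raw backend reply and leaves the old error handling commented out, so it currently returns None. A sketch of the same function with that handling restored, keeping the (success, response, error) tuple the pre-prototype code returned (prepare_json() is the function defined earlier in this file):

import requests

from llm_server import opts


def generate(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None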

llm_server/llm/info.py (new file, 28 lines)
View File

@@ -0,0 +1,28 @@
+import requests
+
+from llm_server import opts
+
+
+def get_running_model():
+    if opts.mode == 'oobabooga':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['result']
+        except Exception as e:
+            return False
+    elif opts.mode == 'hf-textgen':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/info')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['model_id'].replace('/', '_')
+        except Exception as e:
+            return False
+    else:
+        raise Exception
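
The new llm_server/llm/info.py centralises model discovery for both backend modes and returns False on any failure rather than raising. A minimal usage sketch (how the routes below actually call it is not shown in this diff):

from llm_server.llm.info import get_running_model

model = get_running_model()
if model is False:
    print('backend appears to be offline')
else:
    print(f'backend is serving: {model}')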

View File

@@ -1,15 +1,3 @@
-import requests
-from llm_server import opts
-def get_running_model():
-    try:
-        backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
-    except Exception as e:
-        return False
-    try:
-        r_json = backend_response.json()
-        return r_json['result']
-    except Exception as e:
-        return False

View File

@@ -6,14 +6,16 @@ from ..helpers.http import cache_control, validate_json
from ... import opts
from ...database import log_prompt

-if opts.mode == 'oobabooga':
-    from ...llm.oobabooga.generate import generate
-    generator = generate
-elif opts.mode == 'hf-textgen':
-    from ...llm.hf_textgen.generate import generate
-    generator = generate
+def generator(request_json_body):
+    if opts.mode == 'oobabooga':
+        from ...llm.oobabooga.generate import generate
+        return generate(request_json_body)
+    elif opts.mode == 'hf-textgen':
+        from ...llm.hf_textgen.generate import generate
+        return generate(request_json_body)
+    else:
+        raise Exception


@bp.route('/generate', methods=['POST'])
@@ -49,7 +51,7 @@ def generate():
    token = request.headers.get('X-Api-Key')
-    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers))
+    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
    return jsonify({
        **response_json_body
    }), 200
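
Moving the mode check inside generator() means opts.mode is read on every request instead of once at import time, so the blueprint can be imported before server.py has loaded config.yml. An illustrative call, assuming this module is importable as llm_server.routes.v1.generate:

from llm_server import opts
from llm_server.routes.v1.generate import generator  # assumed module path

opts.mode = 'hf-textgen'                    # normally filled in by server.py from config.yml
opts.backend_url = 'http://172.0.0.2:9104'  # value from the sample config above
opts.token_limit = 5555

result = generator({'prompt': 'Hello', 'temperature': 0.7})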

View File

@ -4,7 +4,8 @@ from flask import jsonify
from . import bp
from ..helpers.http import cache_control
-from ...llm.oobabooga.info import get_running_model
+from ... import opts
+from ...llm.info import get_running_model
from ..cache import cache

View File

@@ -10,7 +10,7 @@ from .. import stats
from ..cache import cache
from ..helpers.http import cache_control
from ..stats import proompters_1_min
-from ...llm.oobabooga.info import get_running_model
+from ...llm.info import get_running_model
@bp.route('/stats', methods=['GET'])

View File

@@ -9,7 +9,7 @@ from llm_server import opts
from llm_server.config import ConfigLoader
from llm_server.database import init_db
from llm_server.helpers import resolve_path
-from llm_server.llm.oobabooga.info import get_running_model
+from llm_server.llm.info import get_running_model
from llm_server.routes.cache import cache
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.v1 import bp
@@ -23,7 +23,7 @@ else:
    config_path = Path(script_path, 'config', 'config.yml')

default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': ''}
-required_vars = []
+required_vars = ['token_limit']
config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()
if not success:
@@ -46,9 +46,11 @@ opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
+opts.token_limit = config['token_limit']

app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache
# with app.app_context():
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')
@@ -69,4 +71,4 @@ def fallback(first=None, rest=None):
if __name__ == "__main__":
-    app.run(host='0.0.0.0')
+    app.run(host='0.0.0.0', debug=True)
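
Throughout the commit, opts is used as a plain module-level settings holder that server.py fills in from config.yml. Its source file is not part of this diff, so the sketch below is an assumption based only on the attribute names referenced here, with defaults taken from default_vars and the sample config:

# llm_server/opts.py (assumed contents, for illustration only)
mode = 'oobabooga'
backend_url = None          # e.g. http://172.0.0.2:9104 from config.yml
token_limit = None          # newly required; server.py sets opts.token_limit = config['token_limit']
database_path = './proxy-server.db'
auth_required = False
log_prompts = False
concurrent_gens = 3
frontend_api_client = ''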