prototype hf-textgen and adjust logging

parent a59dcea2da
commit 0d32db2dbd

@@ -5,6 +5,7 @@ log_prompts: true
 mode: oobabooga
 auth_required: false
 concurrent_gens: 3
+token_limit: 5555
 backend_url: http://172.0.0.2:9104

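Note on the new `token_limit` key: based on the server.py hunks later in this commit, the value is declared as a required config variable and copied onto the `opts` module at startup, roughly as sketched below (the failure branch is elided in the diff and assumed here).

```python
# Sketch pieced together from the server.py hunks below; not verbatim code.
required_vars = ['token_limit']                    # startup now requires the key
config_loader = ConfigLoader(config_path, default_vars, required_vars)
success, config, msg = config_loader.load_config()
if not success:
    raise SystemExit(msg)                          # assumed: the diff elides this branch
opts.token_limit = config['token_limit']           # later caps max_new_tokens per request
```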
@@ -22,6 +22,7 @@ def init_db(db_path):
         prompt_tokens INTEGER,
         response TEXT,
         response_tokens INTEGER,
+        response_status INTEGER,
         parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
         headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
         timestamp INTEGER
@@ -31,11 +32,15 @@ def init_db(db_path):
         CREATE TABLE token_auth
         (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
     ''')
+    # c.execute('''
+    #     CREATE TABLE leeches
+    #     (url TEXT, online TEXT)
+    # ''')
     conn.commit()
     conn.close()


-def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
+def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
     prompt_tokens = len(tokenizer.encode(prompt))
     response_tokens = len(tokenizer.encode(response))

@@ -45,8 +50,8 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
     timestamp = int(time.time())
     conn = sqlite3.connect(db_path)
     c = conn.cursor()
-    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
-              (ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
+    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+              (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     conn.commit()
     conn.close()

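For reference, the call site in the generate route (further down in this diff) now passes the backend's HTTP status as the extra argument; a minimal sketch of the updated call, with illustrative values:

```python
# Mirrors the updated call in the /generate route; argument values are illustrative.
log_prompt(
    opts.database_path,
    client_ip,                                  # e.g. '203.0.113.7'
    token,                                      # X-Api-Key header, may be None
    request_json_body['prompt'],
    response_json_body['results'][0]['text'],
    parameters,
    dict(request.headers),
    response.status_code,                       # new: stored in the response_status column
)
```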
@@ -1,36 +1,44 @@
+import json
+
 import requests
 from flask import current_app

 from llm_server import opts
+from llm_server.database import tokenizer


 def prepare_json(json_data: dict):
-    token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
+    token_count = len(tokenizer.encode(json_data.get('prompt', '')))
     seed = json_data.get('seed', None)
     if seed == -1:
         seed = None
+    typical_p = json_data.get('typical_p', None)
+    if typical_p >= 1:
+        typical_p = 0.999
     return {
         'inputs': json_data.get('prompt', ''),
         'parameters': {
-            'max_new_tokens': token_count - opts.token_limit,
+            'max_new_tokens': opts.token_limit - token_count,
             'repetition_penalty': json_data.get('repetition_penalty', None),
             'seed': seed,
             'stop': json_data.get('stopping_strings', []),
             'temperature': json_data.get('temperature', None),
             'top_k': json_data.get('top_k', None),
             'top_p': json_data.get('top_p', None),
-            'truncate': True,
-            'typical_p': json_data.get('typical_p', None),
+            # 'truncate': opts.token_limit,
+            'typical_p': typical_p,
             'watermark': False
         }
     }


 def generate(json_data: dict):
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    if r.status_code != 200:
-        return False, r, f'Backend returned {r.status_code}'
-    return True, r, None
+    print(json.dumps(prepare_json(json_data)))
+    # try:
+    r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
+    print(r.text)
+    # except Exception as e:
+    #     return False, None, f'{e.__class__.__name__}: {e}'
+    # if r.status_code != 200:
+    #     return False, r, f'Backend returned {r.status_code}'
+    # return True, r, None

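The sign flip on `max_new_tokens` is the substantive fix in this hunk: the old `token_count - opts.token_limit` goes negative whenever the prompt is shorter than the limit, while the new `opts.token_limit - token_count` treats the limit as a total budget and hands the remainder to the backend. A minimal sketch of that budget, assuming the `tokenizer` from `llm_server.database`; the clamp to zero is an extra safeguard, not part of the commit:

```python
# Assumes: tokenizer from llm_server.database, token_limit: 5555 from the config hunk above.
def remaining_tokens(prompt: str, token_limit: int = 5555) -> int:
    token_count = len(tokenizer.encode(prompt))
    # The commit computes token_limit - token_count directly; clamping to zero
    # for over-long prompts is an assumed extra safeguard.
    return max(token_limit - token_count, 0)

# e.g. a 500-token prompt leaves 5055 tokens for generation.
```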
@@ -0,0 +1,28 @@
+import requests
+
+from llm_server import opts
+
+
+def get_running_model():
+    if opts.mode == 'oobabooga':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['result']
+        except Exception as e:
+            return False
+    elif opts.mode == 'hf-textgen':
+        try:
+            backend_response = requests.get(f'{opts.backend_url}/info')
+        except Exception as e:
+            return False
+        try:
+            r_json = backend_response.json()
+            return r_json['model_id'].replace('/', '_')
+        except Exception as e:
+            return False
+    else:
+        raise Exception

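The new `llm_server.llm.info.get_running_model()` returns either a model-name string or `False`, so callers need an explicit falsy check. A hedged usage sketch (not code from the commit) of how the routes that now import it might consume it:

```python
from llm_server.llm.info import get_running_model

model = get_running_model()
if model is False:
    # backend unreachable or returned unexpected JSON
    print('backend offline')
else:
    # for hf-textgen the slash in model_id is replaced, e.g. 'bigscience_bloom'
    print(f'running model: {model}')
```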
@@ -1,15 +1,3 @@
-import requests
-
-from llm_server import opts
-
-
-def get_running_model():
-    try:
-        backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
-    except Exception as e:
-        return False
-    try:
-        r_json = backend_response.json()
-        return r_json['result']
-    except Exception as e:
-        return False

@@ -6,14 +6,16 @@ from ..helpers.http import cache_control, validate_json
 from ... import opts
 from ...database import log_prompt

-if opts.mode == 'oobabooga':
-    from ...llm.oobabooga.generate import generate
-
-    generator = generate
-elif opts.mode == 'hf-textgen':
-    from ...llm.hf_textgen.generate import generate
-
-    generator = generate
+def generator(request_json_body):
+    if opts.mode == 'oobabooga':
+        from ...llm.oobabooga.generate import generate
+        return generate(request_json_body)
+    elif opts.mode == 'hf-textgen':
+        from ...llm.hf_textgen.generate import generate
+        return generate(request_json_body)
+    else:
+        raise Exception


 @bp.route('/generate', methods=['POST'])
@@ -49,7 +51,7 @@ def generate():

     token = request.headers.get('X-Api-Key')

-    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers))
+    log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
     return jsonify({
         **response_json_body
     }), 200

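Moving the mode check inside `generator()` defers the backend choice to request time, presumably so `opts.mode` is already populated from the config instead of being frozen at import. A rough idea of how a route could consume it, assuming the `(success, response, error)` tuple returned by the oobabooga `generate()` shown removed above; the error-response shape is an assumption:

```python
# Hypothetical route-side usage; not code from the commit.
def handle(request_json_body):
    success, backend_response, error_msg = generator(request_json_body)
    if not success:
        return jsonify({'error': error_msg}), 500   # assumed error format
    return jsonify(backend_response.json()), 200
```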
@@ -4,7 +4,8 @@ from flask import jsonify
 from . import bp
 from ..helpers.http import cache_control
-from ...llm.oobabooga.info import get_running_model
+from ... import opts
+from ...llm.info import get_running_model
 from ..cache import cache

@@ -10,7 +10,7 @@ from .. import stats
 from ..cache import cache
 from ..helpers.http import cache_control
 from ..stats import proompters_1_min
-from ...llm.oobabooga.info import get_running_model
+from ...llm.info import get_running_model


 @bp.route('/stats', methods=['GET'])

@@ -9,7 +9,7 @@ from llm_server import opts
 from llm_server.config import ConfigLoader
 from llm_server.database import init_db
 from llm_server.helpers import resolve_path
-from llm_server.llm.oobabooga.info import get_running_model
+from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.v1 import bp
@@ -23,7 +23,7 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')

 default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': ''}
-required_vars = []
+required_vars = ['token_limit']
 config_loader = ConfigLoader(config_path, default_vars, required_vars)
 success, config, msg = config_loader.load_config()
 if not success:
@@ -46,9 +46,11 @@ opts.auth_required = config['auth_required']
 opts.log_prompts = config['log_prompts']
 opts.concurrent_gens = config['concurrent_gens']
 opts.frontend_api_client = config['frontend_api_client']
+opts.token_limit = config['token_limit']

 app = Flask(__name__)
 cache.init_app(app)
+cache.clear()  # clear redis cache
 # with app.app_context():
 #     current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
 app.register_blueprint(bp, url_prefix='/api/v1/')

|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(host='0.0.0.0')
|
app.run(host='0.0.0.0', debug=True)
|
||||||