prototype hf-textgen and adjust logging
This commit is contained in:
parent
a59dcea2da
commit
0d32db2dbd
|
@ -5,6 +5,7 @@ log_prompts: true
|
|||
mode: oobabooga
|
||||
auth_required: false
|
||||
concurrent_gens: 3
|
||||
token_limit: 5555
|
||||
|
||||
backend_url: http://172.0.0.2:9104
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ def init_db(db_path):
|
|||
prompt_tokens INTEGER,
|
||||
response TEXT,
|
||||
response_tokens INTEGER,
|
||||
response_status INTEGER,
|
||||
parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
|
||||
headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
|
||||
timestamp INTEGER
|
||||
|
@ -31,11 +32,15 @@ def init_db(db_path):
|
|||
CREATE TABLE token_auth
|
||||
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
|
||||
''')
|
||||
# c.execute('''
|
||||
# CREATE TABLE leeches
|
||||
# (url TEXT, online TEXT)
|
||||
# ''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
|
||||
def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
|
||||
prompt_tokens = len(tokenizer.encode(prompt))
|
||||
response_tokens = len(tokenizer.encode(response))
|
||||
|
||||
|
@ -45,8 +50,8 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers):
|
|||
timestamp = int(time.time())
|
||||
conn = sqlite3.connect(db_path)
|
||||
c = conn.cursor()
|
||||
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
|
|
@ -1,36 +1,44 @@
|
|||
import json
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import tokenizer
|
||||
|
||||
|
||||
def prepare_json(json_data: dict):
|
||||
token_count = len(current_app.tokenizer.encode(json_data.get('prompt', '')))
|
||||
token_count = len(tokenizer.encode(json_data.get('prompt', '')))
|
||||
seed = json_data.get('seed', None)
|
||||
if seed == -1:
|
||||
seed = None
|
||||
typical_p = json_data.get('typical_p', None)
|
||||
if typical_p >= 1:
|
||||
typical_p = 0.999
|
||||
return {
|
||||
'inputs': json_data.get('prompt', ''),
|
||||
'parameters': {
|
||||
'max_new_tokens': token_count - opts.token_limit,
|
||||
'max_new_tokens': opts.token_limit - token_count,
|
||||
'repetition_penalty': json_data.get('repetition_penalty', None),
|
||||
'seed': seed,
|
||||
'stop': json_data.get('stopping_strings', []),
|
||||
'temperature': json_data.get('temperature', None),
|
||||
'top_k': json_data.get('top_k', None),
|
||||
'top_p': json_data.get('top_p', None),
|
||||
'truncate': True,
|
||||
'typical_p': json_data.get('typical_p', None),
|
||||
# 'truncate': opts.token_limit,
|
||||
'typical_p': typical_p,
|
||||
'watermark': False
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def generate(json_data: dict):
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
|
||||
except Exception as e:
|
||||
return False, None, f'{e.__class__.__name__}: {e}'
|
||||
if r.status_code != 200:
|
||||
return False, r, f'Backend returned {r.status_code}'
|
||||
return True, r, None
|
||||
print(json.dumps(prepare_json(json_data)))
|
||||
# try:
|
||||
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
|
||||
print(r.text)
|
||||
# except Exception as e:
|
||||
# return False, None, f'{e.__class__.__name__}: {e}'
|
||||
# if r.status_code != 200:
|
||||
# return False, r, f'Backend returned {r.status_code}'
|
||||
# return True, r, None
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def get_running_model():
|
||||
if opts.mode == 'oobabooga':
|
||||
try:
|
||||
backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
|
||||
except Exception as e:
|
||||
return False
|
||||
try:
|
||||
r_json = backend_response.json()
|
||||
return r_json['result']
|
||||
except Exception as e:
|
||||
return False
|
||||
elif opts.mode == 'hf-textgen':
|
||||
try:
|
||||
backend_response = requests.get(f'{opts.backend_url}/info')
|
||||
except Exception as e:
|
||||
return False
|
||||
try:
|
||||
r_json = backend_response.json()
|
||||
return r_json['model_id'].replace('/', '_')
|
||||
except Exception as e:
|
||||
return False
|
||||
else:
|
||||
raise Exception
|
|
@ -1,15 +1,3 @@
|
|||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def get_running_model():
|
||||
try:
|
||||
backend_response = requests.get(f'{opts.backend_url}/api/v1/model')
|
||||
except Exception as e:
|
||||
return False
|
||||
try:
|
||||
r_json = backend_response.json()
|
||||
return r_json['result']
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
|
|
|
@ -6,14 +6,16 @@ from ..helpers.http import cache_control, validate_json
|
|||
from ... import opts
|
||||
from ...database import log_prompt
|
||||
|
||||
if opts.mode == 'oobabooga':
|
||||
from ...llm.oobabooga.generate import generate
|
||||
|
||||
generator = generate
|
||||
elif opts.mode == 'hf-textgen':
|
||||
from ...llm.hf_textgen.generate import generate
|
||||
|
||||
generator = generate
|
||||
def generator(request_json_body):
|
||||
if opts.mode == 'oobabooga':
|
||||
from ...llm.oobabooga.generate import generate
|
||||
return generate(request_json_body)
|
||||
elif opts.mode == 'hf-textgen':
|
||||
from ...llm.hf_textgen.generate import generate
|
||||
return generate(request_json_body)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
||||
@bp.route('/generate', methods=['POST'])
|
||||
|
@ -49,7 +51,7 @@ def generate():
|
|||
|
||||
token = request.headers.get('X-Api-Key')
|
||||
|
||||
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers))
|
||||
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
|
||||
return jsonify({
|
||||
**response_json_body
|
||||
}), 200
|
||||
|
|
|
@ -4,7 +4,8 @@ from flask import jsonify
|
|||
|
||||
from . import bp
|
||||
from ..helpers.http import cache_control
|
||||
from ...llm.oobabooga.info import get_running_model
|
||||
from ... import opts
|
||||
from ...llm.info import get_running_model
|
||||
from ..cache import cache
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from .. import stats
|
|||
from ..cache import cache
|
||||
from ..helpers.http import cache_control
|
||||
from ..stats import proompters_1_min
|
||||
from ...llm.oobabooga.info import get_running_model
|
||||
from ...llm.info import get_running_model
|
||||
|
||||
|
||||
@bp.route('/stats', methods=['GET'])
|
||||
|
|
|
@ -9,7 +9,7 @@ from llm_server import opts
|
|||
from llm_server.config import ConfigLoader
|
||||
from llm_server.database import init_db
|
||||
from llm_server.helpers import resolve_path
|
||||
from llm_server.llm.oobabooga.info import get_running_model
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.routes.cache import cache
|
||||
from llm_server.routes.helpers.http import cache_control
|
||||
from llm_server.routes.v1 import bp
|
||||
|
@ -23,7 +23,7 @@ else:
|
|||
config_path = Path(script_path, 'config', 'config.yml')
|
||||
|
||||
default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': ''}
|
||||
required_vars = []
|
||||
required_vars = ['token_limit']
|
||||
config_loader = ConfigLoader(config_path, default_vars, required_vars)
|
||||
success, config, msg = config_loader.load_config()
|
||||
if not success:
|
||||
|
@ -46,9 +46,11 @@ opts.auth_required = config['auth_required']
|
|||
opts.log_prompts = config['log_prompts']
|
||||
opts.concurrent_gens = config['concurrent_gens']
|
||||
opts.frontend_api_client = config['frontend_api_client']
|
||||
opts.token_limit = config['token_limit']
|
||||
|
||||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
# with app.app_context():
|
||||
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
|
@ -69,4 +71,4 @@ def fallback(first=None, rest=None):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host='0.0.0.0')
|
||||
app.run(host='0.0.0.0', debug=True)
|
||||
|
|
Reference in New Issue