diff --git a/config/config.yml b/config/config.yml index 87c3d50..c0c3083 100644 --- a/config/config.yml +++ b/config/config.yml @@ -5,6 +5,7 @@ log_prompts: true mode: oobabooga auth_required: false concurrent_gens: 3 +token_limit: 5555 backend_url: http://172.0.0.2:9104 diff --git a/llm_server/database.py b/llm_server/database.py index a367679..afa9519 100644 --- a/llm_server/database.py +++ b/llm_server/database.py @@ -22,6 +22,7 @@ def init_db(db_path): prompt_tokens INTEGER, response TEXT, response_tokens INTEGER, + response_status INTEGER, parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)), headers TEXT CHECK (headers IS NULL OR json_valid(headers)), timestamp INTEGER @@ -31,11 +32,15 @@ def init_db(db_path): CREATE TABLE token_auth (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0) ''') + # c.execute(''' + # CREATE TABLE leeches + # (url TEXT, online TEXT) + # ''') conn.commit() conn.close() -def log_prompt(db_path, ip, token, prompt, response, parameters, headers): +def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code): prompt_tokens = len(tokenizer.encode(prompt)) response_tokens = len(tokenizer.encode(response)) @@ -45,8 +50,8 @@ def log_prompt(db_path, ip, token, prompt, response, parameters, headers): timestamp = int(time.time()) conn = sqlite3.connect(db_path) c = conn.cursor() - c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (ip, token, prompt, prompt_tokens, response, response_tokens, json.dumps(parameters), json.dumps(headers), timestamp)) + c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) conn.commit() conn.close() diff --git a/llm_server/llm/hf_textgen/generate.py b/llm_server/llm/hf_textgen/generate.py index 5093938..9b7dfdc 100644 --- a/llm_server/llm/hf_textgen/generate.py +++ b/llm_server/llm/hf_textgen/generate.py @@ -1,36 +1,44 @@ +import json + import requests from flask import current_app from llm_server import opts +from llm_server.database import tokenizer def prepare_json(json_data: dict): - token_count = len(current_app.tokenizer.encode(json_data.get('prompt', ''))) + token_count = len(tokenizer.encode(json_data.get('prompt', ''))) seed = json_data.get('seed', None) if seed == -1: seed = None + typical_p = json_data.get('typical_p', None) + if typical_p >= 1: + typical_p = 0.999 return { 'inputs': json_data.get('prompt', ''), 'parameters': { - 'max_new_tokens': token_count - opts.token_limit, + 'max_new_tokens': opts.token_limit - token_count, 'repetition_penalty': json_data.get('repetition_penalty', None), 'seed': seed, 'stop': json_data.get('stopping_strings', []), 'temperature': json_data.get('temperature', None), 'top_k': json_data.get('top_k', None), 'top_p': json_data.get('top_p', None), - 'truncate': True, - 'typical_p': json_data.get('typical_p', None), + # 'truncate': opts.token_limit, + 'typical_p': typical_p, 'watermark': False } } def generate(json_data: dict): - try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data)) - except Exception as e: - return False, None, f'{e.__class__.__name__}: {e}' - if r.status_code != 200: - return False, r, f'Backend returned {r.status_code}' - return True, r, None + print(json.dumps(prepare_json(json_data))) + # try: + r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data)) + print(r.text) + # except Exception as e: + # return False, None, f'{e.__class__.__name__}: {e}' + # if r.status_code != 200: + # return False, r, f'Backend returned {r.status_code}' + # return True, r, None diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py new file mode 100644 index 0000000..907e10e --- /dev/null +++ b/llm_server/llm/info.py @@ -0,0 +1,28 @@ +import requests + +from llm_server import opts + + +def get_running_model(): + if opts.mode == 'oobabooga': + try: + backend_response = requests.get(f'{opts.backend_url}/api/v1/model') + except Exception as e: + return False + try: + r_json = backend_response.json() + return r_json['result'] + except Exception as e: + return False + elif opts.mode == 'hf-textgen': + try: + backend_response = requests.get(f'{opts.backend_url}/info') + except Exception as e: + return False + try: + r_json = backend_response.json() + return r_json['model_id'].replace('/', '_') + except Exception as e: + return False + else: + raise Exception diff --git a/llm_server/llm/oobabooga/info.py b/llm_server/llm/oobabooga/info.py index e27b2fd..b28b04f 100644 --- a/llm_server/llm/oobabooga/info.py +++ b/llm_server/llm/oobabooga/info.py @@ -1,15 +1,3 @@ -import requests - -from llm_server import opts -def get_running_model(): - try: - backend_response = requests.get(f'{opts.backend_url}/api/v1/model') - except Exception as e: - return False - try: - r_json = backend_response.json() - return r_json['result'] - except Exception as e: - return False \ No newline at end of file + diff --git a/llm_server/llm/tokenizer.py b/llm_server/llm/tokenizer.py deleted file mode 100644 index e69de29..0000000 diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index c283c33..36fac22 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -6,14 +6,16 @@ from ..helpers.http import cache_control, validate_json from ... import opts from ...database import log_prompt -if opts.mode == 'oobabooga': - from ...llm.oobabooga.generate import generate - generator = generate -elif opts.mode == 'hf-textgen': - from ...llm.hf_textgen.generate import generate - - generator = generate +def generator(request_json_body): + if opts.mode == 'oobabooga': + from ...llm.oobabooga.generate import generate + return generate(request_json_body) + elif opts.mode == 'hf-textgen': + from ...llm.hf_textgen.generate import generate + return generate(request_json_body) + else: + raise Exception @bp.route('/generate', methods=['POST']) @@ -49,7 +51,7 @@ def generate(): token = request.headers.get('X-Api-Key') - log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers)) + log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code) return jsonify({ **response_json_body }), 200 diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index 45970c4..e4774e0 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -4,7 +4,8 @@ from flask import jsonify from . import bp from ..helpers.http import cache_control -from ...llm.oobabooga.info import get_running_model +from ... import opts +from ...llm.info import get_running_model from ..cache import cache diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py index 83a7a70..ec2b8f7 100644 --- a/llm_server/routes/v1/proxy.py +++ b/llm_server/routes/v1/proxy.py @@ -10,7 +10,7 @@ from .. import stats from ..cache import cache from ..helpers.http import cache_control from ..stats import proompters_1_min -from ...llm.oobabooga.info import get_running_model +from ...llm.info import get_running_model @bp.route('/stats', methods=['GET']) diff --git a/server.py b/server.py index 833d8df..820d6af 100644 --- a/server.py +++ b/server.py @@ -9,7 +9,7 @@ from llm_server import opts from llm_server.config import ConfigLoader from llm_server.database import init_db from llm_server.helpers import resolve_path -from llm_server.llm.oobabooga.info import get_running_model +from llm_server.llm.info import get_running_model from llm_server.routes.cache import cache from llm_server.routes.helpers.http import cache_control from llm_server.routes.v1 import bp @@ -23,7 +23,7 @@ else: config_path = Path(script_path, 'config', 'config.yml') default_vars = {'mode': 'oobabooga', 'log_prompts': False, 'database_path': './proxy-server.db', 'auth_required': False, 'concurrent_gens': 3, 'frontend_api_client': ''} -required_vars = [] +required_vars = ['token_limit'] config_loader = ConfigLoader(config_path, default_vars, required_vars) success, config, msg = config_loader.load_config() if not success: @@ -46,9 +46,11 @@ opts.auth_required = config['auth_required'] opts.log_prompts = config['log_prompts'] opts.concurrent_gens = config['concurrent_gens'] opts.frontend_api_client = config['frontend_api_client'] +opts.token_limit = config['token_limit'] app = Flask(__name__) cache.init_app(app) +cache.clear() # clear redis cache # with app.app_context(): # current_app.tokenizer = tiktoken.get_encoding("cl100k_base") app.register_blueprint(bp, url_prefix='/api/v1/') @@ -69,4 +71,4 @@ def fallback(first=None, rest=None): if __name__ == "__main__": - app.run(host='0.0.0.0') + app.run(host='0.0.0.0', debug=True)