From 1d9f40765e12b38ea5dcd335655f67388524a804 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Tue, 12 Sep 2023 13:09:47 -0600
Subject: [PATCH] remove text-generation-inference backend

---
 llm_server/config.py                          |  1 -
 llm_server/llm/generator.py                   |  3 -
 llm_server/llm/hf_textgen/__init__.py         |  1 -
 llm_server/llm/hf_textgen/generate.py         | 51 ----------------
 .../llm/hf_textgen/hf_textgen_backend.py      | 59 -------------------
 llm_server/llm/hf_textgen/info.py             | 17 ------
 llm_server/llm/info.py                        |  7 ---
 llm_server/routes/request_handler.py          |  2 -
 llm_server/routes/v1/generate_stats.py        |  6 --
 llm_server/threads.py                         | 11 ----
 server.py                                     |  2 +-
 11 files changed, 1 insertion(+), 159 deletions(-)
 delete mode 100644 llm_server/llm/hf_textgen/__init__.py
 delete mode 100644 llm_server/llm/hf_textgen/generate.py
 delete mode 100644 llm_server/llm/hf_textgen/hf_textgen_backend.py
 delete mode 100644 llm_server/llm/hf_textgen/info.py

diff --git a/llm_server/config.py b/llm_server/config.py
index 4e610a3..ba64952 100644
--- a/llm_server/config.py
+++ b/llm_server/config.py
@@ -22,7 +22,6 @@ config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middlewar
 
 mode_ui_names = {
     'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
-    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
     'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }
 
diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py
index 55e2cc4..3aca56b 100644
--- a/llm_server/llm/generator.py
+++ b/llm_server/llm/generator.py
@@ -5,9 +5,6 @@ def generator(request_json_body):
     if opts.mode == 'oobabooga':
         from .oobabooga.generate import generate
         return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from .hf_textgen.generate import generate
-        return generate(request_json_body)
     elif opts.mode == 'vllm':
         from .vllm.generate import generate
         r = generate(request_json_body)
diff --git a/llm_server/llm/hf_textgen/__init__.py b/llm_server/llm/hf_textgen/__init__.py
deleted file mode 100644
index c1e8529..0000000
--- a/llm_server/llm/hf_textgen/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# https://huggingface.github.io/text-generation-inference
diff --git a/llm_server/llm/hf_textgen/generate.py b/llm_server/llm/hf_textgen/generate.py
deleted file mode 100644
index 648e2fc..0000000
--- a/llm_server/llm/hf_textgen/generate.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""
-This file is used by the worker that processes requests.
-"""
-
-import requests
-
-from llm_server import opts
-
-
-def prepare_json(json_data: dict):
-    # token_count = len(tokenizer.encode(json_data.get('prompt', '')))
-    seed = json_data.get('seed', None)
-    if seed == -1:
-        seed = None
-    typical_p = json_data.get('typical_p', None)
-    if typical_p >= 1:
-        # https://github.com/huggingface/text-generation-inference/issues/929
-        typical_p = 0.998
-    return {
-        'inputs': json_data.get('prompt', ''),
-        'parameters': {
-            'max_new_tokens': min(json_data.get('max_new_tokens', opts.max_new_tokens), opts.max_new_tokens),
-            'repetition_penalty': json_data.get('repetition_penalty', None),
-            'seed': seed,
-            'stop': json_data.get('stopping_strings', []),
-            'temperature': json_data.get('temperature', None),
-            'top_k': json_data.get('top_k', None),
-            'top_p': json_data.get('top_p', None),
-            # 'truncate': opts.token_limit,
-            'typical_p': typical_p,
-            'watermark': False,
-            'do_sample': json_data.get('do_sample', False),
-            'return_full_text': False,
-            'details': True,
-        }
-    }
-
-
-def generate(json_data: dict):
-    assert json_data.get('typical_p', 0) < 0.999
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    return True, r, None
-
-# except Exception as e:
-#     return False, None, f'{e.__class__.__name__}: {e}'
-# if r.status_code != 200:
-#     return False, r, f'Backend returned {r.status_code}'
-# return True, r, None
diff --git a/llm_server/llm/hf_textgen/hf_textgen_backend.py b/llm_server/llm/hf_textgen/hf_textgen_backend.py
deleted file mode 100644
index c1b3513..0000000
--- a/llm_server/llm/hf_textgen/hf_textgen_backend.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from typing import Tuple
-
-import requests
-from flask import jsonify
-
-from llm_server import opts
-from llm_server.database import log_prompt
-from llm_server.helpers import indefinite_article
-from llm_server.llm.llm_backend import LLMBackend
-from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-
-
-class HfTextgenLLMBackend(LLMBackend):
-    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
-        response_valid_json, response_json_body = validate_json(response)
-        backend_err = False
-        try:
-            response_status_code = response.status_code
-        except:
-            response_status_code = 0
-
-        if response_valid_json:
-            backend_response = response_json_body.get('generated_text', '')
-
-            if response_json_body.get('error'):
-                backend_err = True
-                error_type = response_json_body.get('error_type')
-                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
-                backend_response = format_sillytavern_err(
-                    f'Backend (hf-textgen) {error_type_string}: {response_json_body.get("error")}',
-                    f'HTTP CODE {response_status_code}'
-                )
-
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-            return jsonify({
-                'results': [{'text': backend_response}]
-            }), 200
-        else:
-            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': 'the backend did not return valid JSON',
-                'results': [{'text': backend_response}]
-            }), 200
-
-    def validate_params(self, params_dict: dict):
-        if params_dict.get('typical_p', 0) > 0.998:
-            return False, '`typical_p` must be less than 0.999'
-        return True, None
-
-    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
-    #     try:
-    #         backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-    #         r_json = backend_response.json()
-    #         return r_json['model_id'].replace('/', '_'), None
-    #     except Exception as e:
-    #         return False, e
diff --git a/llm_server/llm/hf_textgen/info.py b/llm_server/llm/hf_textgen/info.py
deleted file mode 100644
index a766000..0000000
--- a/llm_server/llm/hf_textgen/info.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""
-Extra info that is added to the home page.
-"""
-
-hf_textget_info = """
-
-Important: This endpoint is running text-generation-inference and not all Oobabooga parameters are supported.
-
-Supported Parameters:
-"""
diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py
index 4121d3e..0718dfd 100644
--- a/llm_server/llm/info.py
+++ b/llm_server/llm/info.py
@@ -14,13 +14,6 @@ def get_running_model():
             return r_json['result'], None
         except Exception as e:
             return False, e
-    elif opts.mode == 'hf-textgen':
-        try:
-            backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-            r_json = backend_response.json()
-            return r_json['model_id'].replace('/', '_'), None
-        except Exception as e:
-            return False, e
     elif opts.mode == 'vllm':
         try:
             backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index da3c238..dd29ee8 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -70,8 +70,6 @@ class OobaRequestHandler:
     def get_backend(self):
         if opts.mode == 'oobabooga':
             return OobaboogaLLMBackend()
-        elif opts.mode == 'hf-textgen':
-            return HfTextgenLLMBackend()
         elif opts.mode == 'vllm':
             return VLLMBackend()
         else:
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 0eea965..2959baf 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -104,10 +104,4 @@ def generate_stats():
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
-
-    # if opts.mode in ['oobabooga', 'hf-textgen']:
-    #     output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
-    # else:
-    #     output['endpoints']['streaming'] = None
-
     return deep_sort(output)
diff --git a/llm_server/threads.py b/llm_server/threads.py
index 5ce9c0e..a5a6162 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -42,17 +42,6 @@ class MainBackgroundThread(Thread):
                 else:
                     opts.running_model = model
                     redis.set('backend_online', 1)
-            elif opts.mode == 'hf-textgen':
-                try:
-                    r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
-                    j = r.json()
-                    opts.running_model = j['model_id'].replace('/', '_')
-                    redis.set('backend_online', 1)
-                    redis.set_dict('backend_info', j)
-                except Exception as e:
-                    redis.set('backend_online', 0)
-                    # TODO: handle error
-                    print(e)
             elif opts.mode == 'vllm':
                 model, err = get_running_model()
                 if err:
diff --git a/server.py b/server.py
index e50e45a..f6b9f42 100644
--- a/server.py
+++ b/server.py
@@ -50,7 +50,7 @@ if config['database_path'].startswith('./'):
     opts.database_path = resolve_path(config['database_path'])
 init_db()
 
-if config['mode'] not in ['oobabooga', 'hf-textgen', 'vllm']:
+if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)
 opts.mode = config['mode']
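
For anyone who was pointing this proxy at a text-generation-inference instance, the code removed above boils down to one POST against TGI's /generate endpoint with an 'inputs' string and a 'parameters' dict. Below is a minimal standalone sketch of that call; the base URL and the sampling values are illustrative placeholders rather than project defaults, and it only covers the happy path (no SillyTavern error formatting or prompt logging like the removed HfTextgenLLMBackend did).

# Minimal sketch of what the removed prepare_json()/generate() pair did:
# build a TGI payload and POST it to /generate.
import requests


def tgi_generate(base_url: str, prompt: str, max_new_tokens: int = 250) -> str:
    payload = {
        'inputs': prompt,
        'parameters': {
            'max_new_tokens': max_new_tokens,
            'temperature': 0.7,   # placeholder sampling values
            'top_p': 0.9,
            'typical_p': 0.998,   # TGI rejects typical_p >= 1 (see the issue linked in the removed code)
            'do_sample': True,
            'return_full_text': False,
            'details': True,
        },
    }
    r = requests.post(f'{base_url}/generate', json=payload, timeout=60)
    r.raise_for_status()
    # On success TGI returns {'generated_text': ..., 'details': {...}}.
    return r.json()['generated_text']


# Example, assuming a TGI server is listening locally:
#   print(tgi_generate('http://127.0.0.1:8080', 'Once upon a time'))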