remove text-generation-inference backend

Cyberes 2023-09-12 13:09:47 -06:00
parent 6152b1bb66
commit 1d9f40765e
11 changed files with 1 addition and 159 deletions

View File

@@ -22,7 +22,6 @@ config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middlewar
 mode_ui_names = {
     'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
-    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
     'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }

View File

@@ -5,9 +5,6 @@ def generator(request_json_body):
     if opts.mode == 'oobabooga':
         from .oobabooga.generate import generate
         return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from .hf_textgen.generate import generate
-        return generate(request_json_body)
     elif opts.mode == 'vllm':
         from .vllm.generate import generate
         r = generate(request_json_body)

View File

@@ -1 +0,0 @@
-# https://huggingface.github.io/text-generation-inference

View File

@@ -1,51 +0,0 @@
-"""
-This file is used by the worker that processes requests.
-"""
-import requests
-
-from llm_server import opts
-
-
-def prepare_json(json_data: dict):
-    # token_count = len(tokenizer.encode(json_data.get('prompt', '')))
-    seed = json_data.get('seed', None)
-    if seed == -1:
-        seed = None
-    typical_p = json_data.get('typical_p', None)
-    if typical_p >= 1:
-        # https://github.com/huggingface/text-generation-inference/issues/929
-        typical_p = 0.998
-    return {
-        'inputs': json_data.get('prompt', ''),
-        'parameters': {
-            'max_new_tokens': min(json_data.get('max_new_tokens', opts.max_new_tokens), opts.max_new_tokens),
-            'repetition_penalty': json_data.get('repetition_penalty', None),
-            'seed': seed,
-            'stop': json_data.get('stopping_strings', []),
-            'temperature': json_data.get('temperature', None),
-            'top_k': json_data.get('top_k', None),
-            'top_p': json_data.get('top_p', None),
-            # 'truncate': opts.token_limit,
-            'typical_p': typical_p,
-            'watermark': False,
-            'do_sample': json_data.get('do_sample', False),
-            'return_full_text': False,
-            'details': True,
-        }
-    }
-
-
-def generate(json_data: dict):
-    assert json_data.get('typical_p', 0) < 0.999
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    return True, r, None
-
-    # except Exception as e:
-    #     return False, None, f'{e.__class__.__name__}: {e}'
-    # if r.status_code != 200:
-    #     return False, r, f'Backend returned {r.status_code}'
-    # return True, r, None

View File

@@ -1,59 +0,0 @@
-from typing import Tuple
-
-import requests
-from flask import jsonify
-
-from llm_server import opts
-from llm_server.database import log_prompt
-from llm_server.helpers import indefinite_article
-from llm_server.llm.llm_backend import LLMBackend
-from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-
-
-class HfTextgenLLMBackend(LLMBackend):
-    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
-        response_valid_json, response_json_body = validate_json(response)
-        backend_err = False
-        try:
-            response_status_code = response.status_code
-        except:
-            response_status_code = 0
-        if response_valid_json:
-            backend_response = response_json_body.get('generated_text', '')
-            if response_json_body.get('error'):
-                backend_err = True
-                error_type = response_json_body.get('error_type')
-                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
-                backend_response = format_sillytavern_err(
-                    f'Backend (hf-textgen) {error_type_string}: {response_json_body.get("error")}',
-                    f'HTTP CODE {response_status_code}'
-                )
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-            return jsonify({
-                'results': [{'text': backend_response}]
-            }), 200
-        else:
-            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': 'the backend did not return valid JSON',
-                'results': [{'text': backend_response}]
-            }), 200
-
-    def validate_params(self, params_dict: dict):
-        if params_dict.get('typical_p', 0) > 0.998:
-            return False, '`typical_p` must be less than 0.999'
-        return True, None
-
-    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
-    #     try:
-    #         backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-    #         r_json = backend_response.json()
-    #         return r_json['model_id'].replace('/', '_'), None
-    #     except Exception as e:
-    #         return False, e

View File

@@ -1,17 +0,0 @@
-"""
-Extra info that is added to the home page.
-"""
-
-hf_textget_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/huggingface/text-generation-inference" target="_blank">text-generation-inference</a> and not all Oobabooga parameters are supported.</p>
-<strong>Supported Parameters:</strong>
-<ul>
-<li><kbd>do_sample</kbd></li>
-<li><kbd>max_new_tokens</kbd></li>
-<li><kbd>repetition_penalty</kbd></li>
-<li><kbd>seed</kbd></li>
-<li><kbd>temperature</kbd></li>
-<li><kbd>stopping_strings</kbd></li>
-<li><kbd>top_k</kbd></li>
-<li><kbd>top_p</kbd></li>
-<li><kbd>typical_p</kbd></li>
-</ul>"""

View File

@@ -14,13 +14,6 @@ def get_running_model():
             return r_json['result'], None
         except Exception as e:
             return False, e
-    elif opts.mode == 'hf-textgen':
-        try:
-            backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-            r_json = backend_response.json()
-            return r_json['model_id'].replace('/', '_'), None
-        except Exception as e:
-            return False, e
     elif opts.mode == 'vllm':
         try:
             backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)

View File

@@ -70,8 +70,6 @@ class OobaRequestHandler:
     def get_backend(self):
         if opts.mode == 'oobabooga':
             return OobaboogaLLMBackend()
-        elif opts.mode == 'hf-textgen':
-            return HfTextgenLLMBackend()
         elif opts.mode == 'vllm':
             return VLLMBackend()
         else:

View File

@@ -104,10 +104,4 @@ def generate_stats():
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
-
-    # if opts.mode in ['oobabooga', 'hf-textgen']:
-    #     output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
-    # else:
-    #     output['endpoints']['streaming'] = None
-
     return deep_sort(output)

View File

@@ -42,17 +42,6 @@ class MainBackgroundThread(Thread):
                 else:
                     opts.running_model = model
                     redis.set('backend_online', 1)
-            elif opts.mode == 'hf-textgen':
-                try:
-                    r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
-                    j = r.json()
-                    opts.running_model = j['model_id'].replace('/', '_')
-                    redis.set('backend_online', 1)
-                    redis.set_dict('backend_info', j)
-                except Exception as e:
-                    redis.set('backend_online', 0)
-                    # TODO: handle error
-                    print(e)
             elif opts.mode == 'vllm':
                 model, err = get_running_model()
                 if err:

View File

@@ -50,7 +50,7 @@ if config['database_path'].startswith('./'):
     opts.database_path = resolve_path(config['database_path'])
 init_db()
-if config['mode'] not in ['oobabooga', 'hf-textgen', 'vllm']:
+if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)
 opts.mode = config['mode']
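
For reference, the only modes the server accepts after this removal are 'oobabooga' and 'vllm'. Below is a minimal, self-contained Python sketch of the dispatch that remains; select_backend and its string return values are illustrative stand-ins for OobaRequestHandler.get_backend() and the real backend classes, not code from the repository.

import sys

ACCEPTED_MODES = ('oobabooga', 'vllm')  # 'hf-textgen' is no longer recognized


def select_backend(mode: str) -> str:
    # Reject unknown modes at startup, mirroring the config check in the final hunk above.
    if mode not in ACCEPTED_MODES:
        print('Unknown mode:', mode)
        sys.exit(1)
    # The real handler returns OobaboogaLLMBackend() or VLLMBackend() instances.
    return 'OobaboogaLLMBackend' if mode == 'oobabooga' else 'VLLMBackend'


if __name__ == '__main__':
    print(select_backend('vllm'))  # prints: VLLMBackend
    # select_backend('hf-textgen') would now print 'Unknown mode: hf-textgen' and exit with status 1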