remove text-generation-inference backend
parent 6152b1bb66
commit 1d9f40765e
@@ -22,7 +22,6 @@ config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middlewar
 mode_ui_names = {
     'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
-    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
     'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }
 
 
@@ -5,9 +5,6 @@ def generator(request_json_body):
     if opts.mode == 'oobabooga':
         from .oobabooga.generate import generate
         return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from .hf_textgen.generate import generate
-        return generate(request_json_body)
     elif opts.mode == 'vllm':
         from .vllm.generate import generate
         r = generate(request_json_body)
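For reference, a minimal sketch of what the generator dispatcher reduces to once the hf-textgen branch above is gone. The tail of the vllm branch and the final fallback are cut off in this hunk, so the last two lines below are assumptions rather than the file's actual code.

# Sketch only; the return of `r` and the else fallback are guesses (the hunk is truncated there).
from llm_server import opts

def generator(request_json_body):
    if opts.mode == 'oobabooga':
        from .oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'vllm':
        from .vllm.generate import generate
        r = generate(request_json_body)
        return r  # assumption: the real code may post-process r before returning
    else:
        raise Exception(f'Unknown mode: {opts.mode}')  # assumption: real fallback not shown in this hunk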
@@ -1 +0,0 @@
-# https://huggingface.github.io/text-generation-inference
@@ -1,51 +0,0 @@
-"""
-This file is used by the worker that processes requests.
-"""
-
-import requests
-
-from llm_server import opts
-
-
-def prepare_json(json_data: dict):
-    # token_count = len(tokenizer.encode(json_data.get('prompt', '')))
-    seed = json_data.get('seed', None)
-    if seed == -1:
-        seed = None
-    typical_p = json_data.get('typical_p', None)
-    if typical_p >= 1:
-        # https://github.com/huggingface/text-generation-inference/issues/929
-        typical_p = 0.998
-    return {
-        'inputs': json_data.get('prompt', ''),
-        'parameters': {
-            'max_new_tokens': min(json_data.get('max_new_tokens', opts.max_new_tokens), opts.max_new_tokens),
-            'repetition_penalty': json_data.get('repetition_penalty', None),
-            'seed': seed,
-            'stop': json_data.get('stopping_strings', []),
-            'temperature': json_data.get('temperature', None),
-            'top_k': json_data.get('top_k', None),
-            'top_p': json_data.get('top_p', None),
-            # 'truncate': opts.token_limit,
-            'typical_p': typical_p,
-            'watermark': False,
-            'do_sample': json_data.get('do_sample', False),
-            'return_full_text': False,
-            'details': True,
-        }
-    }
-
-
-def generate(json_data: dict):
-    assert json_data.get('typical_p', 0) < 0.999
-    try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
-    except Exception as e:
-        return False, None, f'{e.__class__.__name__}: {e}'
-    return True, r, None
-
-    # except Exception as e:
-    #     return False, None, f'{e.__class__.__name__}: {e}'
-    # if r.status_code != 200:
-    #     return False, r, f'Backend returned {r.status_code}'
-    # return True, r, None
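The deleted prepare_json above was the Oobabooga-to-TGI translation layer: it renamed fields, turned the sentinel seed of -1 into None, and clamped typical_p below 1.0 per the linked TGI issue. A self-contained illustration of that mapping; the sample request and the expected payload are invented, and opts.max_new_tokens is assumed to be large enough not to cap the request.

# Illustration only: invented sample values showing the request-to-payload mapping.
request_json_body = {
    'prompt': 'Write a haiku about GPUs.',
    'max_new_tokens': 200,
    'temperature': 0.7,
    'top_p': 0.9,
    'typical_p': 1.0,          # >= 1 was clamped to 0.998 (TGI issue #929)
    'seed': -1,                # -1 meant "random seed" and became None
    'stopping_strings': ['\n\n'],
    'do_sample': True,
}

# prepare_json(request_json_body) would have produced roughly:
expected_payload = {
    'inputs': 'Write a haiku about GPUs.',
    'parameters': {
        'max_new_tokens': 200,          # min(requested, opts.max_new_tokens)
        'repetition_penalty': None,     # absent in the request, so passed through as None
        'seed': None,
        'stop': ['\n\n'],
        'temperature': 0.7,
        'top_k': None,
        'top_p': 0.9,
        'typical_p': 0.998,
        'watermark': False,
        'do_sample': True,
        'return_full_text': False,
        'details': True,
    },
}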
@@ -1,59 +0,0 @@
-from typing import Tuple
-
-import requests
-from flask import jsonify
-
-from llm_server import opts
-from llm_server.database import log_prompt
-from llm_server.helpers import indefinite_article
-from llm_server.llm.llm_backend import LLMBackend
-from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-
-
-class HfTextgenLLMBackend(LLMBackend):
-    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
-        response_valid_json, response_json_body = validate_json(response)
-        backend_err = False
-        try:
-            response_status_code = response.status_code
-        except:
-            response_status_code = 0
-
-        if response_valid_json:
-            backend_response = response_json_body.get('generated_text', '')
-
-            if response_json_body.get('error'):
-                backend_err = True
-                error_type = response_json_body.get('error_type')
-                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
-                backend_response = format_sillytavern_err(
-                    f'Backend (hf-textgen) {error_type_string}: {response_json_body.get("error")}',
-                    f'HTTP CODE {response_status_code}'
-                )
-
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-            return jsonify({
-                'results': [{'text': backend_response}]
-            }), 200
-        else:
-            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': 'the backend did not return valid JSON',
-                'results': [{'text': backend_response}]
-            }), 200
-
-    def validate_params(self, params_dict: dict):
-        if params_dict.get('typical_p', 0) > 0.998:
-            return False, '`typical_p` must be less than 0.999'
-        return True, None
-
-    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
-    #     try:
-    #         backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-    #         r_json = backend_response.json()
-    #         return r_json['model_id'].replace('/', '_'), None
-    #     except Exception as e:
-    #         return False, e
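The typical_p constraint appears twice in the deleted code: prepare_json clamps it and validate_params rejects it. A standalone restatement of the validator's (ok, error) contract, with invented inputs:

def check_typical_p(params_dict: dict):
    # Mirrors the deleted HfTextgenLLMBackend.validate_params guard.
    if params_dict.get('typical_p', 0) > 0.998:
        return False, '`typical_p` must be less than 0.999'
    return True, None


print(check_typical_p({'typical_p': 0.95}))  # (True, None)
print(check_typical_p({'typical_p': 1.0}))   # (False, '`typical_p` must be less than 0.999')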
@@ -1,17 +0,0 @@
-"""
-Extra info that is added to the home page.
-"""
-
-hf_textget_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/huggingface/text-generation-inference" target="_blank">text-generation-inference</a> and not all Oobabooga parameters are supported.</p>
-<strong>Supported Parameters:</strong>
-<ul>
-<li><kbd>do_sample</kbd></li>
-<li><kbd>max_new_tokens</kbd></li>
-<li><kbd>repetition_penalty</kbd></li>
-<li><kbd>seed</kbd></li>
-<li><kbd>temperature</kbd></li>
-<li><kbd>stopping_strings</kbd></li>
-<li><kbd>top_k</kbd></li>
-<li><kbd>top_p</kbd></li>
-<li><kbd>typical_p</kbd></li>
-</ul>"""
@@ -14,13 +14,6 @@ def get_running_model():
             return r_json['result'], None
         except Exception as e:
             return False, e
-    elif opts.mode == 'hf-textgen':
-        try:
-            backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
-            r_json = backend_response.json()
-            return r_json['model_id'].replace('/', '_'), None
-        except Exception as e:
-            return False, e
    elif opts.mode == 'vllm':
        try:
            backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
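Every branch of get_running_model() returns either (model_name, None) or (False, exception), which is what the background thread further down unpacks. A self-contained restatement of the pattern the deleted hf-textgen branch followed; the URL here is a placeholder, not the project's configuration:

import requests


def fetch_model_id(backend_url: str, verify_ssl: bool = True):
    # Same (result, error) convention as get_running_model(): (value, None) on success, (False, exc) on failure.
    try:
        r = requests.get(f'{backend_url}/info', timeout=3, verify=verify_ssl)
        return r.json()['model_id'].replace('/', '_'), None
    except Exception as e:
        return False, e


model, err = fetch_model_id('http://127.0.0.1:8080')  # placeholder URL for illustration
if err:
    print('backend unreachable:', err)
else:
    print('running model:', model)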
@@ -70,8 +70,6 @@ class OobaRequestHandler:
     def get_backend(self):
         if opts.mode == 'oobabooga':
             return OobaboogaLLMBackend()
-        elif opts.mode == 'hf-textgen':
-            return HfTextgenLLMBackend()
         elif opts.mode == 'vllm':
             return VLLMBackend()
         else:
@@ -104,10 +104,4 @@ def generate_stats():
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
-
-    # if opts.mode in ['oobabooga', 'hf-textgen']:
-    #     output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
-    # else:
-    #     output['endpoints']['streaming'] = None
-
     return deep_sort(output)
@@ -42,17 +42,6 @@ class MainBackgroundThread(Thread):
                 else:
                     opts.running_model = model
                     redis.set('backend_online', 1)
-            elif opts.mode == 'hf-textgen':
-                try:
-                    r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
-                    j = r.json()
-                    opts.running_model = j['model_id'].replace('/', '_')
-                    redis.set('backend_online', 1)
-                    redis.set_dict('backend_info', j)
-                except Exception as e:
-                    redis.set('backend_online', 0)
-                    # TODO: handle error
-                    print(e)
             elif opts.mode == 'vllm':
                 model, err = get_running_model()
                 if err:
@@ -50,7 +50,7 @@ if config['database_path'].startswith('./'):
     opts.database_path = resolve_path(config['database_path'])
 init_db()
 
-if config['mode'] not in ['oobabooga', 'hf-textgen', 'vllm']:
+if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)
 opts.mode = config['mode']
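With this change, any config still set to the old mode fails the startup check. A minimal restatement of the new guard; the inline config dict is a stand-in, since the real values come from the server's config file at startup:

import sys

config = {'mode': 'hf-textgen'}  # stand-in value; the real config is loaded by the server

if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)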