add HF text-generation-inference backend

This commit is contained in:
Cyberes 2023-08-29 13:46:41 -06:00
parent 6c0e60135d
commit ba0bc87434
11 changed files with 148 additions and 71 deletions

View File

@ -14,6 +14,7 @@ config_default_vars = {
'info_html': None,
'show_total_output_tokens': True,
'ip_in_queue_max': 3,
'show_backend_info': True,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -45,11 +45,12 @@ def init_db():
conn.close()
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code):
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None):
prompt_tokens = len(tokenizer.encode(prompt))
if not response_tokens:
response_tokens = len(tokenizer.encode(response))
# Sometimes we may want to insert null into the DB but
# Sometimes we may want to insert null into the DB, but
# usually we want to insert a float.
if gen_time:
gen_time = round(gen_time, 3)

View File

@ -1,3 +1,4 @@
import json
from collections import OrderedDict
from pathlib import Path
@ -39,3 +40,10 @@ def deep_sort(obj):
obj = sorted(obj, key=lambda x: json.dumps(x))
return obj
def indefinite_article(word):
if word[0].lower() in 'aeiou':
return 'an'
else:
return 'a'

View File

@ -1,14 +1,10 @@
import json
import requests
from flask import current_app
from llm_server import opts
from llm_server.database import tokenizer
def prepare_json(json_data: dict):
token_count = len(tokenizer.encode(json_data.get('prompt', '')))
# token_count = len(tokenizer.encode(json_data.get('prompt', '')))
seed = json_data.get('seed', None)
if seed == -1:
seed = None
@ -18,7 +14,7 @@ def prepare_json(json_data: dict):
return {
'inputs': json_data.get('prompt', ''),
'parameters': {
'max_new_tokens': opts.context_size - token_count,
'max_new_tokens': json_data.get('max_new_tokens'),
'repetition_penalty': json_data.get('repetition_penalty', None),
'seed': seed,
'stop': json_data.get('stopping_strings', []),
@ -27,16 +23,22 @@ def prepare_json(json_data: dict):
'top_p': json_data.get('top_p', None),
# 'truncate': opts.token_limit,
'typical_p': typical_p,
'watermark': False
'watermark': False,
'do_sample': json_data.get('do_sample', False),
'return_full_text': False,
'details': True,
}
}
def generate(json_data: dict):
print(json.dumps(prepare_json(json_data)))
# try:
# print(json.dumps(prepare_json(json_data)))
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
print(r.text)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
return True, r, None
# except Exception as e:
# return False, None, f'{e.__class__.__name__}: {e}'
# if r.status_code != 200:

View File

@ -20,3 +20,4 @@ average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
ip_in_queue_max = 3
show_backend_info = True

View File

@ -1,3 +1,5 @@
import json
from flask_caching import Cache
from redis import Redis
from redis.typing import FieldT
@ -37,6 +39,18 @@ class RedisWrapper:
def sismember(self, key: str, value: str):
return self.redis.sismember(f"{self.prefix}:{key}", value)
def set_dict(self, key, dict_value):
# return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value)
return self.set(f"{self.prefix}:{key}", json.dumps(dict_value))
def get_dict(self, key):
# return self.redis.hgetall(f"{self.prefix}:{key}")
r = self.get(f"{self.prefix}:{key}")
if not r:
return dict()
else:
return json.loads(r)
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):

View File

@ -7,8 +7,24 @@ from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
processing_ips = set()
processing_ips_lock = threading.Lock()
redis.set_dict('processing_ips', {})
def increment_ip_count(client_ip: int, redis_key):
ip_count = redis.get_dict(redis_key)
ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
redis.set_dict(redis_key, ip_count)
return ip_count
def decrement_ip_count(client_ip: int, redis_key):
ip_count = redis.get_dict(redis_key)
if client_ip in ip_count.keys():
ip_count[client_ip] -= 1
if ip_count[client_ip] == 0:
del ip_count[client_ip] # Remove the IP from the dictionary if count is 0
redis.set_dict(redis_key, ip_count)
return ip_count
class PriorityQueue:
@ -16,18 +32,21 @@ class PriorityQueue:
self._queue = []
self._index = 0
self._cv = threading.Condition()
self._ip_count = {}
self._lock = threading.Lock()
redis.set_dict('queued_ip_count', {})
def put(self, item, priority):
event = DataEvent()
with self._cv:
# Check if the IP is already in the dictionary and if it has reached the limit
if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
ip_count = redis.get_dict('queued_ip_count')
if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
return None # reject the request
heapq.heappush(self._queue, (-priority, self._index, item, event))
self._index += 1
# Increment the count for this IP
self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
with self._lock:
increment_ip_count(item[1], 'queued_ip_count')
self._cv.notify()
return event
@ -37,9 +56,8 @@ class PriorityQueue:
self._cv.wait()
_, _, item, event = heapq.heappop(self._queue)
# Decrement the count for this IP
self._ip_count[item[1]] -= 1
if self._ip_count[item[1]] == 0:
del self._ip_count[item[1]] # Remove the IP from the dictionary if count is 0
with self._lock:
decrement_ip_count(item[1], 'queued_ip_count')
return item, event
def __len__(self):
@ -60,13 +78,15 @@ def worker():
while True:
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
redis.sadd('processing_ips', client_ip)
# redis.sadd('processing_ips', client_ip)
increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers')
start_time = time.time()
success, response, error_msg = generator(request_json_body)
end_time = time.time()
elapsed_time = end_time - start_time
with generation_elapsed_lock:
generation_elapsed.append((end_time, elapsed_time))
@ -74,7 +94,8 @@ def worker():
event.data = (success, response, error_msg)
event.set()
redis.srem('processing_ips', client_ip)
# redis.srem('processing_ips', client_ip)
decrement_ip_count(client_ip, 'processing_ips')
redis.decr('active_gen_workers')

View File

@ -11,11 +11,13 @@ from ..helpers.http import validate_json
from ..queue import priority_queue
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get
from ...helpers import safe_list_get, indefinite_article
DEFAULT_PRIORITY = 9999
# TODO: clean this up and make the ooba vs hf-textgen more object-oriented
@bp.route('/generate', methods=['POST'])
def generate():
start_time = time.time()
@ -51,13 +53,13 @@ def generate():
else:
print(f'Token {token} was given priority {priority}.')
if not redis.sismember('processing_ips', client_ip) or priority == 0:
queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0)
if queued_ip_count < opts.ip_in_queue_max or priority == 0:
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
else:
event = None
if not event:
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
response_json_body = {
'results': [
@ -66,8 +68,6 @@ def generate():
}
],
}
else:
raise Exception
return jsonify({
**response_json_body
}), 200
@ -75,15 +75,11 @@ def generate():
event.wait()
success, response, error_msg = event.data
# Add the elapsed time to a global list
end_time = time.time()
elapsed_time = end_time - start_time
# print('elapsed:', elapsed_time)
# with wait_in_queue_elapsed_lock:
# wait_in_queue_elapsed.append((end_time, elapsed_time))
if not success or not response:
if opts.mode == 'oobabooga':
if (not success or not response) and opts.mode == 'oobabooga':
# Ooba doesn't return any error messages
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
response_json_body = {
'results': [
@ -92,9 +88,6 @@ def generate():
}
],
}
else:
raise Exception
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0)
return jsonify({
'code': 500,
@ -103,23 +96,47 @@ def generate():
}), 200
response_valid_json, response_json_body = validate_json(response)
backend_err = False
# Return the result to the client
if response_valid_json:
redis.incr('proompts')
if opts.mode == 'oobabooga':
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
if opts.mode == 'oobabooga':
backend_err = True
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
'error')
response_json_body['results'][0]['text'] = backend_response
elif opts.mode == 'hf-textgen':
backend_response = response_json_body.get('generated_text', '')
if response_json_body.get('error'):
error_type = response_json_body.get('error_type')
error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
response_json_body = {
'results': [
{
'text': format_sillytavern_err(
f'Backend ({opts.mode}) {error_type_string}: {response_json_body.get("error")}',
'error')
}
]
}
else:
response_json_body = {
'results': [
{
'text': backend_response
}
]
}
else:
raise Exception
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code)
redis.incr('proompts')
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'))
return jsonify({
**response_json_body
}), 200
else:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')

View File

@ -92,14 +92,15 @@ def generate_stats():
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'queue_size': opts.concurrent_gens,
'concurrent': opts.concurrent_gens,
'model': model_name,
'mode': opts.mode,
'simultaneous_requests': opts.ip_in_queue_max,
'simultaneous_requests_per_ip': opts.ip_in_queue_max,
},
'keys': {
'openaiKeys': '',
'anthropicKeys': '',
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}
return deep_sort(output)

View File

@ -4,7 +4,7 @@ from threading import Thread
import requests
from llm_server import opts
from llm_server.database import average_column_for_model, weighted_average_column_for_model
from llm_server.database import weighted_average_column_for_model
from llm_server.routes.cache import redis
@ -21,6 +21,7 @@ class MainBackgroundThread(Thread):
redis.set('average_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
def run(self):
while True:
@ -34,7 +35,16 @@ class MainBackgroundThread(Thread):
# TODO: handle error
print(e)
elif opts.mode == 'hf-textgen':
pass
try:
r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
j = r.json()
opts.running_model = j['model_id']
redis.set('backend_online', 1)
redis.set_dict('backend_info', j)
except Exception as e:
redis.set('backend_online', 0)
# TODO: handle error
print(e)
else:
raise Exception

View File

@ -53,6 +53,7 @@ opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
opts.show_backend_info = config['show_backend_info']
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl: