add HF text-generation-inference backend

Cyberes 2023-08-29 13:46:41 -06:00
parent 6c0e60135d
commit ba0bc87434
11 changed files with 148 additions and 71 deletions

View File

@@ -14,6 +14,7 @@ config_default_vars = {
     'info_html': None,
     'show_total_output_tokens': True,
     'ip_in_queue_max': 3,
+    'show_backend_info': True,
 }
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@@ -45,11 +45,12 @@ def init_db():
     conn.close()

-def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code):
+def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None):
     prompt_tokens = len(tokenizer.encode(prompt))
-    response_tokens = len(tokenizer.encode(response))
-    # Sometimes we may want to insert null into the DB but
+    if not response_tokens:
+        response_tokens = len(tokenizer.encode(response))
+    # Sometimes we may want to insert null into the DB, but
     # usually we want to insert a float.
     if gen_time:
         gen_time = round(gen_time, 3)
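log_prompt() now accepts an optional response_tokens count so a caller can pass the token total reported by the backend instead of re-encoding the response text. A minimal sketch of the intended call, mirroring the route change later in this commit (variable names are as used in that route and illustrative here):

    # After a successful hf-textgen call, reuse the backend's own token count:
    generated_tokens = response_json_body.get('details', {}).get('generated_tokens')  # e.g. 42
    log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time,
               parameters, dict(request.headers), response.status_code, response_tokens=generated_tokens)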

View File

@@ -1,3 +1,4 @@
+import json
 from collections import OrderedDict
 from pathlib import Path
@@ -39,3 +40,10 @@ def deep_sort(obj):
         obj = sorted(obj, key=lambda x: json.dumps(x))
     return obj
+
+
+def indefinite_article(word):
+    if word[0].lower() in 'aeiou':
+        return 'an'
+    else:
+        return 'a'
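indefinite_article() exists so backend error strings read naturally when an error type is interpolated into them. A tiny sketch, assuming the module is importable as llm_server.helpers (as the relative import in the route file suggests); the error-type names are only examples:

    from llm_server.helpers import indefinite_article

    f"returned {indefinite_article('overloaded')} overloaded error"   # 'returned an overloaded error'
    f"returned {indefinite_article('validation')} validation error"   # 'returned a validation error'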

View File

@@ -1,14 +1,10 @@
-import json
 import requests
-from flask import current_app

 from llm_server import opts
-from llm_server.database import tokenizer

 def prepare_json(json_data: dict):
-    token_count = len(tokenizer.encode(json_data.get('prompt', '')))
+    # token_count = len(tokenizer.encode(json_data.get('prompt', '')))
     seed = json_data.get('seed', None)
     if seed == -1:
         seed = None
@@ -18,7 +14,7 @@ def prepare_json(json_data: dict):
     return {
         'inputs': json_data.get('prompt', ''),
         'parameters': {
-            'max_new_tokens': opts.context_size - token_count,
+            'max_new_tokens': json_data.get('max_new_tokens'),
             'repetition_penalty': json_data.get('repetition_penalty', None),
             'seed': seed,
             'stop': json_data.get('stopping_strings', []),
@@ -27,18 +23,24 @@ def prepare_json(json_data: dict):
             'top_p': json_data.get('top_p', None),
             # 'truncate': opts.token_limit,
             'typical_p': typical_p,
-            'watermark': False
+            'watermark': False,
+            'do_sample': json_data.get('do_sample', False),
+            'return_full_text': False,
+            'details': True,
         }
     }

 def generate(json_data: dict):
-    print(json.dumps(prepare_json(json_data)))
-    # try:
-    r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
-    print(r.text)
-    # except Exception as e:
-    #     return False, None, f'{e.__class__.__name__}: {e}'
-    # if r.status_code != 200:
-    #     return False, r, f'Backend returned {r.status_code}'
-    # return True, r, None
+    # print(json.dumps(prepare_json(json_data)))
+    try:
+        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
+    except Exception as e:
+        return False, None, f'{e.__class__.__name__}: {e}'
+    return True, r, None
+
+    # except Exception as e:
+    #     return False, None, f'{e.__class__.__name__}: {e}'
+    # if r.status_code != 200:
+    #     return False, r, f'Backend returned {r.status_code}'
+    # return True, r, None
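For reference, a rough sketch of the payload prepare_json() now sends to the text-generation-inference /generate route. The concrete values are made up, and the sampler fields hidden between the two hunks above are elided:

    payload = {
        'inputs': 'Once upon a time',       # the client prompt, passed through unchanged
        'parameters': {
            'max_new_tokens': 250,          # now taken from the request instead of context_size - token_count
            'repetition_penalty': 1.15,
            'seed': None,                   # a client seed of -1 is mapped to None
            'stop': ['\nUser:'],
            # ... sampler fields elided in the hunks above ...
            'top_p': 0.9,
            'typical_p': 0.95,
            'watermark': False,
            'do_sample': True,
            'return_full_text': False,      # only return the newly generated text
            'details': True,                # ask TGI for generation metadata such as token counts
        },
    }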

View File

@@ -20,3 +20,4 @@ average_generation_time_mode = 'database'
 show_total_output_tokens = True
 netdata_root = None
 ip_in_queue_max = 3
+show_backend_info = True

View File

@@ -1,3 +1,5 @@
+import json
+
 from flask_caching import Cache
 from redis import Redis
 from redis.typing import FieldT
@@ -37,6 +39,18 @@ class RedisWrapper:
     def sismember(self, key: str, value: str):
         return self.redis.sismember(f"{self.prefix}:{key}", value)

+    def set_dict(self, key, dict_value):
+        # return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value)
+        return self.set(f"{self.prefix}:{key}", json.dumps(dict_value))
+
+    def get_dict(self, key):
+        # return self.redis.hgetall(f"{self.prefix}:{key}")
+        r = self.get(f"{self.prefix}:{key}")
+        if not r:
+            return dict()
+        else:
+            return json.loads(r)
+
     def flush(self):
         flushed = []
         for key in self.redis.scan_iter(f'{self.prefix}:*'):
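set_dict()/get_dict() round-trip a whole dict through a JSON string rather than a Redis hash (the hset/hgetall variants are left commented out), and a missing key reads back as an empty dict. A minimal usage sketch, assuming the shared redis wrapper exported by llm_server.routes.cache and a hypothetical IP key:

    from llm_server.routes.cache import redis

    redis.set_dict('queued_ip_count', {})                # start with an empty mapping
    counts = redis.get_dict('queued_ip_count')           # {} if the key is missing or empty
    counts['10.0.0.5'] = counts.get('10.0.0.5', 0) + 1   # mutate the local copy
    redis.set_dict('queued_ip_count', counts)            # write the whole dict back as JSON

Because each update rewrites the entire blob, these read-modify-write cycles can race, which is presumably why the queue code below wraps its increments and decrements in a threading.Lock.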

View File

@@ -7,8 +7,24 @@ from llm_server.llm.generator import generator
 from llm_server.routes.cache import redis
 from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock

-processing_ips = set()
-processing_ips_lock = threading.Lock()
+redis.set_dict('processing_ips', {})
+
+
+def increment_ip_count(client_ip: int, redis_key):
+    ip_count = redis.get_dict(redis_key)
+    ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
+    redis.set_dict(redis_key, ip_count)
+    return ip_count
+
+
+def decrement_ip_count(client_ip: int, redis_key):
+    ip_count = redis.get_dict(redis_key)
+    if client_ip in ip_count.keys():
+        ip_count[client_ip] -= 1
+        if ip_count[client_ip] == 0:
+            del ip_count[client_ip]  # Remove the IP from the dictionary if count is 0
+    redis.set_dict(redis_key, ip_count)
+    return ip_count


 class PriorityQueue:
@@ -16,18 +32,21 @@ class PriorityQueue:
         self._queue = []
         self._index = 0
         self._cv = threading.Condition()
-        self._ip_count = {}
+        self._lock = threading.Lock()
+        redis.set_dict('queued_ip_count', {})

     def put(self, item, priority):
         event = DataEvent()
         with self._cv:
             # Check if the IP is already in the dictionary and if it has reached the limit
-            if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
+            ip_count = redis.get_dict('queued_ip_count')
+            if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
                 return None  # reject the request
             heapq.heappush(self._queue, (-priority, self._index, item, event))
             self._index += 1
             # Increment the count for this IP
-            self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
+            with self._lock:
+                increment_ip_count(item[1], 'queued_ip_count')
             self._cv.notify()
             return event
@@ -37,9 +56,8 @@ class PriorityQueue:
                 self._cv.wait()
             _, _, item, event = heapq.heappop(self._queue)
             # Decrement the count for this IP
-            self._ip_count[item[1]] -= 1
-            if self._ip_count[item[1]] == 0:
-                del self._ip_count[item[1]]  # Remove the IP from the dictionary if count is 0
+            with self._lock:
+                decrement_ip_count(item[1], 'queued_ip_count')
             return item, event

     def __len__(self):
@@ -60,13 +78,15 @@ def worker():
     while True:
         (request_json_body, client_ip, token, parameters), event = priority_queue.get()

-        redis.sadd('processing_ips', client_ip)
+        # redis.sadd('processing_ips', client_ip)
+        increment_ip_count(client_ip, 'processing_ips')
         redis.incr('active_gen_workers')

        start_time = time.time()
        success, response, error_msg = generator(request_json_body)
        end_time = time.time()
        elapsed_time = end_time - start_time
        with generation_elapsed_lock:
            generation_elapsed.append((end_time, elapsed_time))
@@ -74,7 +94,8 @@ def worker():
         event.data = (success, response, error_msg)
         event.set()

-        redis.srem('processing_ips', client_ip)
+        # redis.srem('processing_ips', client_ip)
+        decrement_ip_count(client_ip, 'processing_ips')
         redis.decr('active_gen_workers')
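The per-IP bookkeeping that used to live in process-local state (the processing_ips set and the queue's _ip_count dict) is now kept in two Redis-backed dicts: 'queued_ip_count' for requests waiting in the priority queue and 'processing_ips' for requests currently being generated. A sketch of one request's lifecycle, assuming the helpers are importable from llm_server.routes.queue (as the route file's relative imports suggest) and using a hypothetical IP:

    from llm_server.routes.queue import increment_ip_count, decrement_ip_count

    increment_ip_count('10.0.0.5', 'queued_ip_count')   # PriorityQueue.put(): request admitted to the queue
    decrement_ip_count('10.0.0.5', 'queued_ip_count')   # PriorityQueue.get(): a worker picked it up
    increment_ip_count('10.0.0.5', 'processing_ips')    # worker(): generation starts
    decrement_ip_count('10.0.0.5', 'processing_ips')    # worker(): response delivered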

View File

@@ -11,11 +11,13 @@ from ..helpers.http import validate_json
 from ..queue import priority_queue
 from ... import opts
 from ...database import log_prompt
-from ...helpers import safe_list_get
+from ...helpers import safe_list_get, indefinite_article

 DEFAULT_PRIORITY = 9999

+# TODO: clean this up and make the ooba vs hf-textgen more object-oriented
+

 @bp.route('/generate', methods=['POST'])
 def generate():
     start_time = time.time()
@@ -51,13 +53,13 @@ def generate():
     else:
         print(f'Token {token} was given priority {priority}.')

-    if not redis.sismember('processing_ips', client_ip) or priority == 0:
+    queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0)
+    if queued_ip_count < opts.ip_in_queue_max or priority == 0:
         event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
     else:
         event = None

     if not event:
         log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
-        if opts.mode == 'oobabooga':
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
         response_json_body = {
             'results': [
@@ -66,8 +68,6 @@ def generate():
                 }
             ],
         }
-        else:
-            raise Exception
         return jsonify({
             **response_json_body
         }), 200
@@ -75,15 +75,11 @@ def generate():
     event.wait()
     success, response, error_msg = event.data

-    # Add the elapsed time to a global list
     end_time = time.time()
     elapsed_time = end_time - start_time
-    # print('elapsed:', elapsed_time)
-    # with wait_in_queue_elapsed_lock:
-    #     wait_in_queue_elapsed.append((end_time, elapsed_time))

-    if not success or not response:
-        if opts.mode == 'oobabooga':
+    if (not success or not response) and opts.mode == 'oobabooga':
+        # Ooba doesn't return any error messages
         backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
         response_json_body = {
             'results': [
@@ -92,9 +88,6 @@ def generate():
                 }
             ],
         }
-        else:
-            raise Exception
         log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0)
         return jsonify({
             'code': 500,
@@ -103,23 +96,47 @@ def generate():
         }), 200

     response_valid_json, response_json_body = validate_json(response)
     backend_err = False

-    # Return the result to the client
     if response_valid_json:
-        redis.incr('proompts')
-        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
-        if not backend_response:
-            if opts.mode == 'oobabooga':
+        if opts.mode == 'oobabooga':
+            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
+            if not backend_response:
                 backend_err = True
                 backend_response = format_sillytavern_err(
                     f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                     'error')
                 response_json_body['results'][0]['text'] = backend_response
+        elif opts.mode == 'hf-textgen':
+            backend_response = response_json_body.get('generated_text', '')
+            if response_json_body.get('error'):
+                error_type = response_json_body.get('error_type')
+                error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
+                response_json_body = {
+                    'results': [
+                        {
+                            'text': format_sillytavern_err(
+                                f'Backend ({opts.mode}) {error_type_string}: {response_json_body.get("error")}',
+                                'error')
+                        }
+                    ]
+                }
+            else:
+                response_json_body = {
+                    'results': [
+                        {
+                            'text': backend_response
+                        }
+                    ]
+                }
         else:
             raise Exception

-        log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code)
+        redis.incr('proompts')
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'))
         return jsonify({
             **response_json_body
         }), 200
     else:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
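The new hf-textgen branch assumes text-generation-inference's /generate response shape: a generated_text field (plus a details object, since prepare_json requests 'details': True) on success, and an error/error_type pair on failure. A hedged sketch of the two shapes being handled, with illustrative values only:

    # Success: mapped to {'results': [{'text': generated_text}]} for SillyTavern clients.
    ok_body = {
        'generated_text': ' and they all lived happily ever after.',
        'details': {'generated_tokens': 11},   # reused as response_tokens in log_prompt()
    }

    # Failure: wrapped in a SillyTavern-style error message such as
    # "Backend (hf-textgen) returned a validation error: ...". The error_type name is an example.
    err_body = {
        'error': 'Input validation error: `max_new_tokens` is out of range',
        'error_type': 'validation',
    }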

View File

@@ -92,14 +92,15 @@ def generate_stats():
         'config': {
             'gatekeeper': 'none' if opts.auth_required is False else 'token',
             'context_size': opts.context_size,
-            'queue_size': opts.concurrent_gens,
+            'concurrent': opts.concurrent_gens,
             'model': model_name,
             'mode': opts.mode,
-            'simultaneous_requests': opts.ip_in_queue_max,
+            'simultaneous_requests_per_ip': opts.ip_in_queue_max,
         },
         'keys': {
             'openaiKeys': '',
             'anthropicKeys': '',
         },
+        'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
     return deep_sort(output)

View File

@@ -4,7 +4,7 @@ from threading import Thread
 import requests

 from llm_server import opts
-from llm_server.database import average_column_for_model, weighted_average_column_for_model
+from llm_server.database import weighted_average_column_for_model
 from llm_server.routes.cache import redis
@@ -21,6 +21,7 @@ class MainBackgroundThread(Thread):
         redis.set('average_tps', 0)
         redis.set('average_output_tokens', 0)
         redis.set('backend_online', 0)
+        redis.set_dict('backend_info', {})

     def run(self):
         while True:
@@ -34,7 +35,16 @@ class MainBackgroundThread(Thread):
                # TODO: handle error
                print(e)
            elif opts.mode == 'hf-textgen':
-                pass
+                try:
+                    r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
+                    j = r.json()
+                    opts.running_model = j['model_id']
+                    redis.set('backend_online', 1)
+                    redis.set_dict('backend_info', j)
+                except Exception as e:
+                    redis.set('backend_online', 0)
+                    # TODO: handle error
+                    print(e)
            else:
                raise Exception
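In hf-textgen mode the background thread now polls TGI's /info route, records the model_id, and stashes the whole JSON document in the backend_info dict that generate_stats() surfaces when show_backend_info is enabled. Roughly, with an assumed /info payload (only model_id is actually read by this commit):

    # Producer (background thread) and consumer (stats route):
    redis.set_dict('backend_info', {'model_id': 'some-org/some-model'})   # fields beyond model_id depend on TGI
    info = redis.get_dict('backend_info') if opts.show_backend_info else None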

View File

@@ -53,6 +53,7 @@ opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
+opts.show_backend_info = config['show_backend_info']
opts.verify_ssl = config['verify_ssl']

if not opts.verify_ssl: