add HF text-generation-inference backend
This commit is contained in:
parent
6c0e60135d
commit
ba0bc87434
|
@ -14,6 +14,7 @@ config_default_vars = {
|
|||
'info_html': None,
|
||||
'show_total_output_tokens': True,
|
||||
'ip_in_queue_max': 3,
|
||||
'show_backend_info': True,
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -45,11 +45,12 @@ def init_db():
|
|||
conn.close()
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code):
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None):
|
||||
prompt_tokens = len(tokenizer.encode(prompt))
|
||||
if not response_tokens:
|
||||
response_tokens = len(tokenizer.encode(response))
|
||||
|
||||
# Sometimes we may want to insert null into the DB but
|
||||
# Sometimes we may want to insert null into the DB, but
|
||||
# usually we want to insert a float.
|
||||
if gen_time:
|
||||
gen_time = round(gen_time, 3)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -39,3 +40,10 @@ def deep_sort(obj):
|
|||
obj = sorted(obj, key=lambda x: json.dumps(x))
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def indefinite_article(word):
|
||||
if word[0].lower() in 'aeiou':
|
||||
return 'an'
|
||||
else:
|
||||
return 'a'
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
import json
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import tokenizer
|
||||
|
||||
|
||||
def prepare_json(json_data: dict):
|
||||
token_count = len(tokenizer.encode(json_data.get('prompt', '')))
|
||||
# token_count = len(tokenizer.encode(json_data.get('prompt', '')))
|
||||
seed = json_data.get('seed', None)
|
||||
if seed == -1:
|
||||
seed = None
|
||||
|
@ -18,7 +14,7 @@ def prepare_json(json_data: dict):
|
|||
return {
|
||||
'inputs': json_data.get('prompt', ''),
|
||||
'parameters': {
|
||||
'max_new_tokens': opts.context_size - token_count,
|
||||
'max_new_tokens': json_data.get('max_new_tokens'),
|
||||
'repetition_penalty': json_data.get('repetition_penalty', None),
|
||||
'seed': seed,
|
||||
'stop': json_data.get('stopping_strings', []),
|
||||
|
@ -27,16 +23,22 @@ def prepare_json(json_data: dict):
|
|||
'top_p': json_data.get('top_p', None),
|
||||
# 'truncate': opts.token_limit,
|
||||
'typical_p': typical_p,
|
||||
'watermark': False
|
||||
'watermark': False,
|
||||
'do_sample': json_data.get('do_sample', False),
|
||||
'return_full_text': False,
|
||||
'details': True,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def generate(json_data: dict):
|
||||
print(json.dumps(prepare_json(json_data)))
|
||||
# try:
|
||||
# print(json.dumps(prepare_json(json_data)))
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
|
||||
print(r.text)
|
||||
except Exception as e:
|
||||
return False, None, f'{e.__class__.__name__}: {e}'
|
||||
return True, r, None
|
||||
|
||||
# except Exception as e:
|
||||
# return False, None, f'{e.__class__.__name__}: {e}'
|
||||
# if r.status_code != 200:
|
||||
|
|
|
@ -20,3 +20,4 @@ average_generation_time_mode = 'database'
|
|||
show_total_output_tokens = True
|
||||
netdata_root = None
|
||||
ip_in_queue_max = 3
|
||||
show_backend_info = True
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import json
|
||||
|
||||
from flask_caching import Cache
|
||||
from redis import Redis
|
||||
from redis.typing import FieldT
|
||||
|
@ -37,6 +39,18 @@ class RedisWrapper:
|
|||
def sismember(self, key: str, value: str):
|
||||
return self.redis.sismember(f"{self.prefix}:{key}", value)
|
||||
|
||||
def set_dict(self, key, dict_value):
|
||||
# return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value)
|
||||
return self.set(f"{self.prefix}:{key}", json.dumps(dict_value))
|
||||
|
||||
def get_dict(self, key):
|
||||
# return self.redis.hgetall(f"{self.prefix}:{key}")
|
||||
r = self.get(f"{self.prefix}:{key}")
|
||||
if not r:
|
||||
return dict()
|
||||
else:
|
||||
return json.loads(r)
|
||||
|
||||
def flush(self):
|
||||
flushed = []
|
||||
for key in self.redis.scan_iter(f'{self.prefix}:*'):
|
||||
|
|
|
@ -7,8 +7,24 @@ from llm_server.llm.generator import generator
|
|||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
|
||||
|
||||
processing_ips = set()
|
||||
processing_ips_lock = threading.Lock()
|
||||
redis.set_dict('processing_ips', {})
|
||||
|
||||
|
||||
def increment_ip_count(client_ip: int, redis_key):
|
||||
ip_count = redis.get_dict(redis_key)
|
||||
ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
|
||||
redis.set_dict(redis_key, ip_count)
|
||||
return ip_count
|
||||
|
||||
|
||||
def decrement_ip_count(client_ip: int, redis_key):
|
||||
ip_count = redis.get_dict(redis_key)
|
||||
if client_ip in ip_count.keys():
|
||||
ip_count[client_ip] -= 1
|
||||
if ip_count[client_ip] == 0:
|
||||
del ip_count[client_ip] # Remove the IP from the dictionary if count is 0
|
||||
redis.set_dict(redis_key, ip_count)
|
||||
return ip_count
|
||||
|
||||
|
||||
class PriorityQueue:
|
||||
|
@ -16,18 +32,21 @@ class PriorityQueue:
|
|||
self._queue = []
|
||||
self._index = 0
|
||||
self._cv = threading.Condition()
|
||||
self._ip_count = {}
|
||||
self._lock = threading.Lock()
|
||||
redis.set_dict('queued_ip_count', {})
|
||||
|
||||
def put(self, item, priority):
|
||||
event = DataEvent()
|
||||
with self._cv:
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
|
||||
ip_count = redis.get_dict('queued_ip_count')
|
||||
if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
|
||||
return None # reject the request
|
||||
heapq.heappush(self._queue, (-priority, self._index, item, event))
|
||||
self._index += 1
|
||||
# Increment the count for this IP
|
||||
self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
|
||||
with self._lock:
|
||||
increment_ip_count(item[1], 'queued_ip_count')
|
||||
self._cv.notify()
|
||||
return event
|
||||
|
||||
|
@ -37,9 +56,8 @@ class PriorityQueue:
|
|||
self._cv.wait()
|
||||
_, _, item, event = heapq.heappop(self._queue)
|
||||
# Decrement the count for this IP
|
||||
self._ip_count[item[1]] -= 1
|
||||
if self._ip_count[item[1]] == 0:
|
||||
del self._ip_count[item[1]] # Remove the IP from the dictionary if count is 0
|
||||
with self._lock:
|
||||
decrement_ip_count(item[1], 'queued_ip_count')
|
||||
return item, event
|
||||
|
||||
def __len__(self):
|
||||
|
@ -60,13 +78,15 @@ def worker():
|
|||
while True:
|
||||
(request_json_body, client_ip, token, parameters), event = priority_queue.get()
|
||||
|
||||
redis.sadd('processing_ips', client_ip)
|
||||
# redis.sadd('processing_ips', client_ip)
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
|
||||
redis.incr('active_gen_workers')
|
||||
|
||||
start_time = time.time()
|
||||
success, response, error_msg = generator(request_json_body)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
elapsed_time = end_time - start_time
|
||||
with generation_elapsed_lock:
|
||||
generation_elapsed.append((end_time, elapsed_time))
|
||||
|
@ -74,7 +94,8 @@ def worker():
|
|||
event.data = (success, response, error_msg)
|
||||
event.set()
|
||||
|
||||
redis.srem('processing_ips', client_ip)
|
||||
# redis.srem('processing_ips', client_ip)
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
redis.decr('active_gen_workers')
|
||||
|
||||
|
||||
|
|
|
@ -11,11 +11,13 @@ from ..helpers.http import validate_json
|
|||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...database import log_prompt
|
||||
from ...helpers import safe_list_get
|
||||
from ...helpers import safe_list_get, indefinite_article
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
# TODO: clean this up and make the ooba vs hf-textgen more object-oriented
|
||||
|
||||
@bp.route('/generate', methods=['POST'])
|
||||
def generate():
|
||||
start_time = time.time()
|
||||
|
@ -51,13 +53,13 @@ def generate():
|
|||
else:
|
||||
print(f'Token {token} was given priority {priority}.')
|
||||
|
||||
if not redis.sismember('processing_ips', client_ip) or priority == 0:
|
||||
queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0)
|
||||
if queued_ip_count < opts.ip_in_queue_max or priority == 0:
|
||||
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
|
||||
else:
|
||||
event = None
|
||||
if not event:
|
||||
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
|
||||
if opts.mode == 'oobabooga':
|
||||
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
|
||||
response_json_body = {
|
||||
'results': [
|
||||
|
@ -66,8 +68,6 @@ def generate():
|
|||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise Exception
|
||||
return jsonify({
|
||||
**response_json_body
|
||||
}), 200
|
||||
|
@ -75,15 +75,11 @@ def generate():
|
|||
event.wait()
|
||||
success, response, error_msg = event.data
|
||||
|
||||
# Add the elapsed time to a global list
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
# print('elapsed:', elapsed_time)
|
||||
# with wait_in_queue_elapsed_lock:
|
||||
# wait_in_queue_elapsed.append((end_time, elapsed_time))
|
||||
|
||||
if not success or not response:
|
||||
if opts.mode == 'oobabooga':
|
||||
if (not success or not response) and opts.mode == 'oobabooga':
|
||||
# Ooba doesn't return any error messages
|
||||
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
|
||||
response_json_body = {
|
||||
'results': [
|
||||
|
@ -92,9 +88,6 @@ def generate():
|
|||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
|
@ -103,23 +96,47 @@ def generate():
|
|||
}), 200
|
||||
response_valid_json, response_json_body = validate_json(response)
|
||||
backend_err = False
|
||||
|
||||
# Return the result to the client
|
||||
if response_valid_json:
|
||||
redis.incr('proompts')
|
||||
if opts.mode == 'oobabooga':
|
||||
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
|
||||
if not backend_response:
|
||||
if opts.mode == 'oobabooga':
|
||||
backend_err = True
|
||||
backend_response = format_sillytavern_err(
|
||||
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
|
||||
'error')
|
||||
response_json_body['results'][0]['text'] = backend_response
|
||||
elif opts.mode == 'hf-textgen':
|
||||
backend_response = response_json_body.get('generated_text', '')
|
||||
if response_json_body.get('error'):
|
||||
error_type = response_json_body.get('error_type')
|
||||
error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
|
||||
response_json_body = {
|
||||
'results': [
|
||||
{
|
||||
'text': format_sillytavern_err(
|
||||
f'Backend ({opts.mode}) {error_type_string}: {response_json_body.get("error")}',
|
||||
'error')
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
response_json_body = {
|
||||
'results': [
|
||||
{
|
||||
'text': backend_response
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code)
|
||||
redis.incr('proompts')
|
||||
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'))
|
||||
return jsonify({
|
||||
**response_json_body
|
||||
}), 200
|
||||
|
||||
else:
|
||||
if opts.mode == 'oobabooga':
|
||||
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
|
||||
|
|
|
@ -92,14 +92,15 @@ def generate_stats():
|
|||
'config': {
|
||||
'gatekeeper': 'none' if opts.auth_required is False else 'token',
|
||||
'context_size': opts.context_size,
|
||||
'queue_size': opts.concurrent_gens,
|
||||
'concurrent': opts.concurrent_gens,
|
||||
'model': model_name,
|
||||
'mode': opts.mode,
|
||||
'simultaneous_requests': opts.ip_in_queue_max,
|
||||
'simultaneous_requests_per_ip': opts.ip_in_queue_max,
|
||||
},
|
||||
'keys': {
|
||||
'openaiKeys': '∞',
|
||||
'anthropicKeys': '∞',
|
||||
},
|
||||
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
|
||||
}
|
||||
return deep_sort(output)
|
||||
|
|
|
@ -4,7 +4,7 @@ from threading import Thread
|
|||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database import average_column_for_model, weighted_average_column_for_model
|
||||
from llm_server.database import weighted_average_column_for_model
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ class MainBackgroundThread(Thread):
|
|||
redis.set('average_tps', 0)
|
||||
redis.set('average_output_tokens', 0)
|
||||
redis.set('backend_online', 0)
|
||||
redis.set_dict('backend_info', {})
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
|
@ -34,7 +35,16 @@ class MainBackgroundThread(Thread):
|
|||
# TODO: handle error
|
||||
print(e)
|
||||
elif opts.mode == 'hf-textgen':
|
||||
pass
|
||||
try:
|
||||
r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
|
||||
j = r.json()
|
||||
opts.running_model = j['model_id']
|
||||
redis.set('backend_online', 1)
|
||||
redis.set_dict('backend_info', j)
|
||||
except Exception as e:
|
||||
redis.set('backend_online', 0)
|
||||
# TODO: handle error
|
||||
print(e)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ opts.backend_url = config['backend_url'].strip('/')
|
|||
opts.show_total_output_tokens = config['show_total_output_tokens']
|
||||
opts.netdata_root = config['netdata_root']
|
||||
opts.ip_in_queue_max = config['ip_in_queue_max']
|
||||
opts.show_backend_info = config['show_backend_info']
|
||||
|
||||
opts.verify_ssl = config['verify_ssl']
|
||||
if not opts.verify_ssl:
|
||||
|
|
Reference in New Issue