From ba0bc874340e026b6806e14c600020dcb41e020d Mon Sep 17 00:00:00 2001 From: Cyberes Date: Tue, 29 Aug 2023 13:46:41 -0600 Subject: [PATCH] add HF text-generation-inference backend --- llm_server/config.py | 1 + llm_server/database.py | 7 +- llm_server/helpers.py | 8 +++ llm_server/llm/hf_textgen/generate.py | 34 +++++----- llm_server/opts.py | 1 + llm_server/routes/cache.py | 14 ++++ llm_server/routes/queue.py | 43 ++++++++---- llm_server/routes/v1/generate.py | 91 +++++++++++++++----------- llm_server/routes/v1/generate_stats.py | 5 +- llm_server/threads.py | 14 +++- server.py | 1 + 11 files changed, 148 insertions(+), 71 deletions(-) diff --git a/llm_server/config.py b/llm_server/config.py index 9c8edfc..ea06815 100644 --- a/llm_server/config.py +++ b/llm_server/config.py @@ -14,6 +14,7 @@ config_default_vars = { 'info_html': None, 'show_total_output_tokens': True, 'ip_in_queue_max': 3, + 'show_backend_info': True, } config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] diff --git a/llm_server/database.py b/llm_server/database.py index db347da..1edde9a 100644 --- a/llm_server/database.py +++ b/llm_server/database.py @@ -45,11 +45,12 @@ def init_db(): conn.close() -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code): +def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None): prompt_tokens = len(tokenizer.encode(prompt)) - response_tokens = len(tokenizer.encode(response)) + if not response_tokens: + response_tokens = len(tokenizer.encode(response)) - # Sometimes we may want to insert null into the DB but + # Sometimes we may want to insert null into the DB, but # usually we want to insert a float. if gen_time: gen_time = round(gen_time, 3) diff --git a/llm_server/helpers.py b/llm_server/helpers.py index c7f42a0..40dd81c 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -1,3 +1,4 @@ +import json from collections import OrderedDict from pathlib import Path @@ -39,3 +40,10 @@ def deep_sort(obj): obj = sorted(obj, key=lambda x: json.dumps(x)) return obj + + +def indefinite_article(word): + if word[0].lower() in 'aeiou': + return 'an' + else: + return 'a' diff --git a/llm_server/llm/hf_textgen/generate.py b/llm_server/llm/hf_textgen/generate.py index 385263e..ee20c04 100644 --- a/llm_server/llm/hf_textgen/generate.py +++ b/llm_server/llm/hf_textgen/generate.py @@ -1,14 +1,10 @@ -import json - import requests -from flask import current_app from llm_server import opts -from llm_server.database import tokenizer def prepare_json(json_data: dict): - token_count = len(tokenizer.encode(json_data.get('prompt', ''))) + # token_count = len(tokenizer.encode(json_data.get('prompt', ''))) seed = json_data.get('seed', None) if seed == -1: seed = None @@ -18,7 +14,7 @@ def prepare_json(json_data: dict): return { 'inputs': json_data.get('prompt', ''), 'parameters': { - 'max_new_tokens': opts.context_size - token_count, + 'max_new_tokens': json_data.get('max_new_tokens'), 'repetition_penalty': json_data.get('repetition_penalty', None), 'seed': seed, 'stop': json_data.get('stopping_strings', []), @@ -27,18 +23,24 @@ def prepare_json(json_data: dict): 'top_p': json_data.get('top_p', None), # 'truncate': opts.token_limit, 'typical_p': typical_p, - 'watermark': False + 'watermark': False, + 'do_sample': json_data.get('do_sample', False), + 'return_full_text': False, + 'details': True, } } def generate(json_data: dict): - 
print(json.dumps(prepare_json(json_data))) - # try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl) - print(r.text) - # except Exception as e: - # return False, None, f'{e.__class__.__name__}: {e}' - # if r.status_code != 200: - # return False, r, f'Backend returned {r.status_code}' - # return True, r, None + # print(json.dumps(prepare_json(json_data))) + try: + r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl) + except Exception as e: + return False, None, f'{e.__class__.__name__}: {e}' + return True, r, None + +# except Exception as e: +# return False, None, f'{e.__class__.__name__}: {e}' +# if r.status_code != 200: +# return False, r, f'Backend returned {r.status_code}' +# return True, r, None diff --git a/llm_server/opts.py b/llm_server/opts.py index bc71047..bfa0b72 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -20,3 +20,4 @@ average_generation_time_mode = 'database' show_total_output_tokens = True netdata_root = None ip_in_queue_max = 3 +show_backend_info = True diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py index ec6aed2..727ea72 100644 --- a/llm_server/routes/cache.py +++ b/llm_server/routes/cache.py @@ -1,3 +1,5 @@ +import json + from flask_caching import Cache from redis import Redis from redis.typing import FieldT @@ -37,6 +39,18 @@ class RedisWrapper: def sismember(self, key: str, value: str): return self.redis.sismember(f"{self.prefix}:{key}", value) + def set_dict(self, key, dict_value): + # return self.redis.hset(f"{self.prefix}:{key}", mapping=dict_value) + return self.set(f"{self.prefix}:{key}", json.dumps(dict_value)) + + def get_dict(self, key): + # return self.redis.hgetall(f"{self.prefix}:{key}") + r = self.get(f"{self.prefix}:{key}") + if not r: + return dict() + else: + return json.loads(r) + def flush(self): flushed = [] for key in self.redis.scan_iter(f'{self.prefix}:*'): diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index d38635f..f104a6f 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -7,8 +7,24 @@ from llm_server.llm.generator import generator from llm_server.routes.cache import redis from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock -processing_ips = set() -processing_ips_lock = threading.Lock() +redis.set_dict('processing_ips', {}) + + +def increment_ip_count(client_ip: int, redis_key): + ip_count = redis.get_dict(redis_key) + ip_count[client_ip] = ip_count.get(client_ip, 0) + 1 + redis.set_dict(redis_key, ip_count) + return ip_count + + +def decrement_ip_count(client_ip: int, redis_key): + ip_count = redis.get_dict(redis_key) + if client_ip in ip_count.keys(): + ip_count[client_ip] -= 1 + if ip_count[client_ip] == 0: + del ip_count[client_ip] # Remove the IP from the dictionary if count is 0 + redis.set_dict(redis_key, ip_count) + return ip_count class PriorityQueue: @@ -16,18 +32,21 @@ class PriorityQueue: self._queue = [] self._index = 0 self._cv = threading.Condition() - self._ip_count = {} + self._lock = threading.Lock() + redis.set_dict('queued_ip_count', {}) def put(self, item, priority): event = DataEvent() with self._cv: # Check if the IP is already in the dictionary and if it has reached the limit - if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0: + ip_count = redis.get_dict('queued_ip_count') + if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority 
!= 0: return None # reject the request heapq.heappush(self._queue, (-priority, self._index, item, event)) self._index += 1 # Increment the count for this IP - self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1 + with self._lock: + increment_ip_count(item[1], 'queued_ip_count') self._cv.notify() return event @@ -37,9 +56,8 @@ class PriorityQueue: self._cv.wait() _, _, item, event = heapq.heappop(self._queue) # Decrement the count for this IP - self._ip_count[item[1]] -= 1 - if self._ip_count[item[1]] == 0: - del self._ip_count[item[1]] # Remove the IP from the dictionary if count is 0 + with self._lock: + decrement_ip_count(item[1], 'queued_ip_count') return item, event def __len__(self): @@ -60,13 +78,15 @@ def worker(): while True: (request_json_body, client_ip, token, parameters), event = priority_queue.get() - redis.sadd('processing_ips', client_ip) + # redis.sadd('processing_ips', client_ip) + increment_ip_count(client_ip, 'processing_ips') + redis.incr('active_gen_workers') start_time = time.time() success, response, error_msg = generator(request_json_body) - end_time = time.time() + elapsed_time = end_time - start_time with generation_elapsed_lock: generation_elapsed.append((end_time, elapsed_time)) @@ -74,7 +94,8 @@ def worker(): event.data = (success, response, error_msg) event.set() - redis.srem('processing_ips', client_ip) + # redis.srem('processing_ips', client_ip) + decrement_ip_count(client_ip, 'processing_ips') redis.decr('active_gen_workers') diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index f1526c0..01d0220 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -11,11 +11,13 @@ from ..helpers.http import validate_json from ..queue import priority_queue from ... import opts from ...database import log_prompt -from ...helpers import safe_list_get +from ...helpers import safe_list_get, indefinite_article DEFAULT_PRIORITY = 9999 +# TODO: clean this up and make the ooba vs hf-textgen more object-oriented + @bp.route('/generate', methods=['POST']) def generate(): start_time = time.time() @@ -51,23 +53,21 @@ def generate(): else: print(f'Token {token} was given priority {priority}.') - if not redis.sismember('processing_ips', client_ip) or priority == 0: + queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0) + if queued_ip_count < opts.ip_in_queue_max or priority == 0: event = priority_queue.put((request_json_body, client_ip, token, parameters), priority) else: event = None if not event: log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429) - if opts.mode == 'oobabooga': - backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error') - response_json_body = { - 'results': [ - { - 'text': backend_response, - } - ], - } - else: - raise Exception + backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. 
Please complete your other requests before sending another.', 'error') + response_json_body = { + 'results': [ + { + 'text': backend_response, + } + ], + } return jsonify({ **response_json_body }), 200 @@ -75,26 +75,19 @@ def generate(): event.wait() success, response, error_msg = event.data - # Add the elapsed time to a global list end_time = time.time() elapsed_time = end_time - start_time - # print('elapsed:', elapsed_time) - # with wait_in_queue_elapsed_lock: - # wait_in_queue_elapsed.append((end_time, elapsed_time)) - - if not success or not response: - if opts.mode == 'oobabooga': - backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error') - response_json_body = { - 'results': [ - { - 'text': backend_response, - } - ], - } - else: - raise Exception + if (not success or not response) and opts.mode == 'oobabooga': + # Ooba doesn't return any error messages + backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error') + response_json_body = { + 'results': [ + { + 'text': backend_response, + } + ], + } log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0) return jsonify({ 'code': 500, @@ -103,23 +96,47 @@ def generate(): }), 200 response_valid_json, response_json_body = validate_json(response) backend_err = False + + # Return the result to the client if response_valid_json: - redis.incr('proompts') - backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text') - if not backend_response: - if opts.mode == 'oobabooga': + if opts.mode == 'oobabooga': + backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text') + if not backend_response: backend_err = True backend_response = format_sillytavern_err( f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. 
Please check your parameters and try again.', 'error') response_json_body['results'][0]['text'] = backend_response + elif opts.mode == 'hf-textgen': + backend_response = response_json_body.get('generated_text', '') + if response_json_body.get('error'): + error_type = response_json_body.get('error_type') + error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error' + response_json_body = { + 'results': [ + { + 'text': format_sillytavern_err( + f'Backend ({opts.mode}) {error_type_string}: {response_json_body.get("error")}', + 'error') + } + ] + } else: - raise Exception - - log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code) + response_json_body = { + 'results': [ + { + 'text': backend_response + } + ] + } + else: + raise Exception + redis.incr('proompts') + log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens')) return jsonify({ **response_json_body }), 200 + else: if opts.mode == 'oobabooga': backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error') diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index fd445e3..23dfc4f 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -92,14 +92,15 @@ def generate_stats(): 'config': { 'gatekeeper': 'none' if opts.auth_required is False else 'token', 'context_size': opts.context_size, - 'queue_size': opts.concurrent_gens, + 'concurrent': opts.concurrent_gens, 'model': model_name, 'mode': opts.mode, - 'simultaneous_requests': opts.ip_in_queue_max, + 'simultaneous_requests_per_ip': opts.ip_in_queue_max, }, 'keys': { 'openaiKeys': '∞', 'anthropicKeys': '∞', }, + 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, } return deep_sort(output) diff --git a/llm_server/threads.py b/llm_server/threads.py index 2a1df2e..dbfb779 100644 --- a/llm_server/threads.py +++ b/llm_server/threads.py @@ -4,7 +4,7 @@ from threading import Thread import requests from llm_server import opts -from llm_server.database import average_column_for_model, weighted_average_column_for_model +from llm_server.database import weighted_average_column_for_model from llm_server.routes.cache import redis @@ -21,6 +21,7 @@ class MainBackgroundThread(Thread): redis.set('average_tps', 0) redis.set('average_output_tokens', 0) redis.set('backend_online', 0) + redis.set_dict('backend_info', {}) def run(self): while True: @@ -34,7 +35,16 @@ class MainBackgroundThread(Thread): # TODO: handle error print(e) elif opts.mode == 'hf-textgen': - pass + try: + r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl) + j = r.json() + opts.running_model = j['model_id'] + redis.set('backend_online', 1) + redis.set_dict('backend_info', j) + except Exception as e: + redis.set('backend_online', 0) + # TODO: handle error + print(e) else: raise Exception diff --git a/server.py b/server.py index 613442d..8e5d04c 100644 --- a/server.py +++ b/server.py @@ -53,6 +53,7 @@ opts.backend_url = config['backend_url'].strip('/') opts.show_total_output_tokens = config['show_total_output_tokens'] 
opts.netdata_root = config['netdata_root'] opts.ip_in_queue_max = config['ip_in_queue_max'] +opts.show_backend_info = config['show_backend_info'] opts.verify_ssl = config['verify_ssl'] if not opts.verify_ssl:
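
Example: in the new hf-textgen mode, prepare_json() above converts a client request into the payload that text-generation-inference's /generate route expects, and the /v1/generate handler then reads generated_text plus details.generated_tokens (enabled by 'details': True) back out of the backend's JSON. The stand-alone sketch below illustrates that round trip; the URL and parameter values are placeholders, and only fields visible in this patch are used.

# Illustrative sketch of the request/response shape used by the new hf-textgen backend.
# Assumes a text-generation-inference server at backend_url (a stand-in for opts.backend_url);
# all values are made up.
import requests

backend_url = 'http://127.0.0.1:8080'

payload = {
    'inputs': 'Once upon a time',
    'parameters': {
        'max_new_tokens': 200,
        'repetition_penalty': 1.15,
        'seed': None,            # a client seed of -1 is mapped to None
        'stop': ['\nUser:'],
        'top_p': 0.9,
        'typical_p': 0.95,
        'watermark': False,
        'do_sample': True,
        'return_full_text': False,
        'details': True,         # ask TGI to include generation details (token counts)
    }
}

r = requests.post(f'{backend_url}/generate', json=payload, timeout=60)
j = r.json()

# The proxy re-wraps generated_text into an oobabooga-style 'results' list and
# passes details['generated_tokens'] to log_prompt() as the response token count.
print(j.get('generated_text', ''))
print(j.get('details', {}).get('generated_tokens'))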