From bf648f605fcb0ec0bdbeb6df422f8791cf64923a Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Tue, 29 Aug 2023 17:56:12 -0600
Subject: [PATCH] implement streaming for hf-textgen

---
 README.md                               |  1 +
 llm_server/config.py                    |  4 +-
 llm_server/database.py                  |  2 +-
 llm_server/opts.py                      |  2 +-
 llm_server/routes/v1/__init__.py        |  8 +--
 llm_server/routes/v1/generate_stats.py  |  3 +-
 llm_server/routes/v1/generate_stream.py | 84 +++++++++++++++++++++++++
 llm_server/stream.py                    |  8 +++
 requirements.txt                        |  3 +-
 server.py                               | 10 ++-
 templates/home.html                     |  5 ++
 11 files changed, 116 insertions(+), 14 deletions(-)
 create mode 100644 llm_server/routes/v1/generate_stream.py
 create mode 100644 llm_server/stream.py

diff --git a/README.md b/README.md
index e71c625..b6abb5f 100644
--- a/README.md
+++ b/README.md
@@ -52,3 +52,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
 - Convince Oobabooga to implement concurrent generation
 - Make sure stats work when starting from an empty database
 - Make sure we're correctly canceling requests when the client cancels
+- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and then remove it before proxying?
\ No newline at end of file
diff --git a/llm_server/config.py b/llm_server/config.py
index 7b53bef..b47f8bc 100644
--- a/llm_server/config.py
+++ b/llm_server/config.py
@@ -19,8 +19,8 @@ config_default_vars = {
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
 
 mode_ui_names = {
-    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
-    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url'),
+    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
+    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }
diff --git a/llm_server/database.py b/llm_server/database.py
index d31e482..e0bb1a5 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -50,7 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
 
     if not is_error:
         if not response_tokens:
-            response_tokens = len(tokenizer.encode(response))
+            response_tokens = len(tokenizer.encode(response, disallowed_special=()))
     else:
         response_tokens = None
diff --git a/llm_server/opts.py b/llm_server/opts.py
index bfa0b72..645ea23 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -11,7 +11,7 @@ database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
-full_client_api = None
+base_client_api = None
 http_host = None
 verify_ssl = True
 show_num_prompts = True
diff --git a/llm_server/routes/v1/__init__.py b/llm_server/routes/v1/__init__.py
index 5a5715d..25ab2f6 100644
--- a/llm_server/routes/v1/__init__.py
+++ b/llm_server/routes/v1/__init__.py
@@ -6,18 +6,16 @@ from ... import opts
 
 bp = Blueprint('v1', __name__)
 
-# openai_bp = Blueprint('/v1', __name__)
-
 
 @bp.before_request
 def before_request():
     if not opts.http_host:
         opts.http_host = request.headers.get("Host")
-    if not opts.full_client_api:
-        opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
+    if not opts.base_client_api:
+        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
             return response
 
 
-from . import generate, info, proxy
+from . import generate, info, proxy, generate_stream
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 8f37f6f..acc28f0 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -81,7 +81,8 @@ def generate_stats():
         },
         'online': online,
         'endpoints': {
-            'blocking': opts.full_client_api,
+            'blocking': f'https://{opts.base_client_api}',
+            'streaming': f'wss://{opts.base_client_api}/v1/stream',
         },
         'queue': {
             'processing': active_gen_workers,
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
new file mode 100644
index 0000000..5b992e7
--- /dev/null
+++ b/llm_server/routes/v1/generate_stream.py
@@ -0,0 +1,84 @@
+import json
+import time
+
+import requests
+from flask import request
+
+from ..helpers.client import format_sillytavern_err
+from ... import opts
+from ...database import log_prompt
+from ...helpers import indefinite_article
+from ...llm.hf_textgen.generate import prepare_json
+from ...stream import sock
+
+
+@sock.route('/api/v1/stream')  # TODO: use blueprint route???
+def stream(ws):
+    start_time = time.time()
+    if request.headers.get('cf-connecting-ip'):
+        client_ip = request.headers.get('cf-connecting-ip')
+    elif request.headers.get('x-forwarded-for'):
+        client_ip = request.headers.get('x-forwarded-for').split(',')[0]
+    else:
+        client_ip = request.remote_addr
+    token = request.headers.get('X-Api-Key')
+
+    message_num = 0
+    while ws.connected:
+        message = ws.receive()
+        data = json.loads(message)
+
+        if opts.mode == 'hf-textgen':
+            response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=False)
+
+            # Be extra careful when getting attributes from the response object
+            try:
+                response_status_code = response.status_code
+            except:
+                response_status_code = 0
+
+            details = {}
+            generated_text = ''
+
+            # Iterate over each line in the response
+            for line in response.iter_lines():
+                # Decode the line to a string
+                line = line.decode('utf-8')
+                # If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
+                if line.startswith('data:'):
+                    line = line[5:]
+                    json_data = json.loads(line)
+                    details = json_data.get('details', {})
+                    generated_text = json_data.get('generated_text', '')
+
+                    if json_data.get('error'):
+                        error_type = json_data.get('error_type')
+                        error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
+                        generated_text = format_sillytavern_err(
+                            f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
+                            f'HTTP CODE {response_status_code}')
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': generated_text
+                        }))
+                        break
+                    else:
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': json_data['token']['text']
+                        }))
+                    message_num += 1
+
+            ws.send(json.dumps({
+                'event': 'stream_end',
+                'message_num': message_num
+            }))
+
+            end_time = time.time()
+            elapsed_time = end_time - start_time
+            parameters = data.copy()
+            del parameters['prompt']
+
+            log_prompt(client_ip, token, data['prompt'], generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details['generated_tokens'])
diff --git a/llm_server/stream.py b/llm_server/stream.py
new file mode 100644
index 0000000..8ac2fc1
--- /dev/null
+++ b/llm_server/stream.py
@@ -0,0 +1,8 @@
+from flask_sock import Sock
+
+sock = Sock()
+
+
+def init_socketio(app):
+    global sock
+    sock.init_app(app)
diff --git a/requirements.txt b/requirements.txt
index 9c24864..6a78098 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ tiktoken
 gunicorn
 redis
 gevent
-async-timeout
\ No newline at end of file
+async-timeout
+flask-sock
diff --git a/server.py b/server.py
index 8182953..cb3804d 100644
--- a/server.py
+++ b/server.py
@@ -17,6 +17,7 @@ from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.stream import init_socketio
 from llm_server.threads import MainBackgroundThread
 
 script_path = os.path.dirname(os.path.realpath(__file__))
@@ -88,6 +89,7 @@ SemaphoreCheckerThread().start()
 app = Flask(__name__)
 cache.init_app(app)
 cache.clear()  # clear redis cache
+init_socketio(app)
 # with app.app_context():
 #     current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
 app.register_blueprint(bp, url_prefix='/api/v1/')
@@ -100,8 +102,8 @@ app.register_blueprint(bp, url_prefix='/api/v1/')
 @app.route('/api')
 @cache.cached(timeout=10, query_string=True)
 def home():
-    if not opts.full_client_api:
-        opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
+    if not opts.base_client_api:
+        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
     stats = generate_stats()
 
     if not bool(redis.get('backend_online')) or not stats['online']:
@@ -131,10 +133,12 @@ def home():
                            analytics_tracking_code=analytics_tracking_code,
                            info_html=info_html,
                            current_model=running_model,
-                           client_api=opts.full_client_api,
+                           client_api=f'https://{opts.base_client_api}',
+                           ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
                            estimated_wait=estimated_wait_sec,
                            mode_name=mode_ui_names[opts.mode][0],
                            api_input_textbox=mode_ui_names[opts.mode][1],
+                           streaming_input_textbox=mode_ui_names[opts.mode][2],
                            context_size=opts.context_size,
                            stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                            extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
diff --git a/templates/home.html b/templates/home.html
index e7167d6..519ce789 100644
--- a/templates/home.html
+++ b/templates/home.html
@@ -71,6 +71,7 @@

 Current Model: {{ current_model }}
 Client API URL: {{ client_api }}
+Streaming API URL: {{ ws_client_api }}
 Estimated Wait Time: {{ estimated_wait }}
 {{ info_html|safe }}
@@ -83,6 +84,10 @@
 1. Set your API type to {{ mode_name }}
 2. Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
+3. Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
 4. If using a token, check the Mancer AI checkbox and enter your token in the Mancer API key textbox.
 5. Click Connect to test the connection.
 6. Open your preset config and set Context Size to {{ context_size }}.
 7. Follow this guide to get set up: rentry.org/freellamas
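A minimal client sketch for manually exercising the new websocket endpoint follows. It is not part of the patch: it assumes the third-party websocket-client package, and the host, API key, and sampling parameters are placeholders. The message shapes ('text_stream' and 'stream_end' events) mirror what generate_stream.py sends.

# Minimal test client for the /api/v1/stream endpoint added above (a sketch, not part of the patch).
# Assumes `pip install websocket-client`; the host, API key, and parameters below are placeholders.
import json

from websocket import create_connection

ws = create_connection(
    'wss://example.com/api/v1/stream',   # placeholder host
    header=['X-Api-Key: my-token'],      # the handler only logs this today; websocket auth is still a TODO
)

# The handler json.loads() each received message and forwards it to the backend through prepare_json(),
# so the prompt and any generation parameters travel together in one JSON object.
ws.send(json.dumps({
    'prompt': 'Write a haiku about reverse proxies.',
    'max_new_tokens': 100,               # placeholder sampling parameter
}))

# Tokens arrive as 'text_stream' events; 'stream_end' marks the end of the generation.
while True:
    event = json.loads(ws.recv())
    if event['event'] == 'text_stream':
        print(event['text'], end='', flush=True)
    elif event['event'] == 'stream_end':
        break

ws.close()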