implement streaming for hf-textgen

This commit is contained in:
Cyberes 2023-08-29 17:56:12 -06:00
parent 26b04f364c
commit bf648f605f
11 changed files with 116 additions and 14 deletions

View File

@ -52,3 +52,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
- Convince Oobabooga to implement concurrent generation - Convince Oobabooga to implement concurrent generation
- Make sure stats work when starting from an empty database - Make sure stats work when starting from an empty database
- Make sure we're correctly canceling requests when the client cancels - Make sure we're correctly canceling requests when the client cancels
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and then remove it before proxying??

View File

@ -19,8 +19,8 @@ config_default_vars = {
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
mode_ui_names = { mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'), 'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url'), 'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
} }

View File

@ -50,7 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if not is_error: if not is_error:
if not response_tokens: if not response_tokens:
response_tokens = len(tokenizer.encode(response)) response_tokens = len(tokenizer.encode(response, disallowed_special=()))
else: else:
response_tokens = None response_tokens = None

View File

@ -11,7 +11,7 @@ database_path = './proxy-server.db'
auth_required = False auth_required = False
log_prompts = False log_prompts = False
frontend_api_client = '' frontend_api_client = ''
full_client_api = None base_client_api = None
http_host = None http_host = None
verify_ssl = True verify_ssl = True
show_num_prompts = True show_num_prompts = True

View File

@ -6,18 +6,16 @@ from ... import opts
bp = Blueprint('v1', __name__) bp = Blueprint('v1', __name__)
# openai_bp = Blueprint('/v1', __name__)
@bp.before_request @bp.before_request
def before_request(): def before_request():
if not opts.http_host: if not opts.http_host:
opts.http_host = request.headers.get("Host") opts.http_host = request.headers.get("Host")
if not opts.full_client_api: if not opts.base_client_api:
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}' opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
if request.endpoint != 'v1.get_stats': if request.endpoint != 'v1.get_stats':
response = require_api_key() response = require_api_key()
if response is not None: if response is not None:
return response return response
from . import generate, info, proxy from . import generate, info, proxy, generate_stream

View File

@ -81,7 +81,8 @@ def generate_stats():
}, },
'online': online, 'online': online,
'endpoints': { 'endpoints': {
'blocking': opts.full_client_api, 'blocking': f'https://{opts.base_client_api}',
'streaming': f'wss://{opts.base_client_api}/v1/stream',
}, },
'queue': { 'queue': {
'processing': active_gen_workers, 'processing': active_gen_workers,

View File

@ -0,0 +1,84 @@
import json
import time
import requests
from flask import request
from ..helpers.client import format_sillytavern_err
from ... import opts
from ...database import log_prompt
from ...helpers import indefinite_article
from ...llm.hf_textgen.generate import prepare_json
from ...stream import sock
@sock.route('/api/v1/stream')  # TODO: use blueprint route???
def stream(ws):
    """Proxy a streaming generation request to the backend over a websocket.

    Receives JSON payloads from the client, forwards each to the hf-textgen
    backend's ``/generate_stream`` SSE endpoint, relays every generated token
    back to the client as a ``text_stream`` event, then emits ``stream_end``
    and logs the completed prompt.
    """
    start_time = time.time()
    client_ip = _determine_client_ip()
    token = request.headers.get('X-Api-Key')
    message_num = 0
    while ws.connected:
        message = ws.receive()
        data = json.loads(message)
        if opts.mode == 'hf-textgen':
            # Honor the configured TLS verification setting instead of
            # unconditionally disabling certificate checks.
            response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=opts.verify_ssl)
            # Be extra careful when getting attributes from the response object
            try:
                response_status_code = response.status_code
            except Exception:  # narrowed from bare except: never swallow SystemExit/KeyboardInterrupt
                response_status_code = 0
            details = {}
            generated_text = ''
            # The backend streams server-sent events: payload lines are
            # prefixed with 'data:' followed by a JSON document.
            for line in response.iter_lines():
                # Decode the line to a string
                line = line.decode('utf-8')
                # If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
                if line.startswith('data:'):
                    json_data = json.loads(line[5:])
                    details = json_data.get('details', {})
                    generated_text = json_data.get('generated_text', '')
                    if json_data.get('error'):
                        error_type = json_data.get('error_type')
                        error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
                        # Surface the backend error to the client as a
                        # SillyTavern-formatted message, then stop streaming.
                        generated_text = format_sillytavern_err(
                            f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
                            f'HTTP CODE {response_status_code}')
                        ws.send(json.dumps({
                            'event': 'text_stream',
                            'message_num': message_num,
                            'text': generated_text
                        }))
                        break
                    else:
                        ws.send(json.dumps({
                            'event': 'text_stream',
                            'message_num': message_num,
                            'text': json_data['token']['text']
                        }))
                    message_num += 1
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))
            elapsed_time = time.time() - start_time
            parameters = data.copy()
            # Don't KeyError if the client omitted 'prompt' from the payload;
            # the prompt is logged separately below.
            parameters.pop('prompt', None)
            # `details` can still be {} here (e.g. the error/break path never
            # saw a final 'details' object), so use .get() instead of indexing.
            log_prompt(client_ip, token, data.get('prompt', ''), generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details.get('generated_tokens'))


def _determine_client_ip():
    """Resolve the originating client IP, preferring proxy-supplied headers."""
    if request.headers.get('cf-connecting-ip'):
        # Cloudflare sets this header to the true client address.
        return request.headers.get('cf-connecting-ip')
    if request.headers.get('x-forwarded-for'):
        # X-Forwarded-For may be a comma-separated chain; the first entry is the client.
        return request.headers.get('x-forwarded-for').split(',')[0]
    return request.remote_addr

8
llm_server/stream.py Normal file
View File

@ -0,0 +1,8 @@
from flask_sock import Sock

# Module-level websocket extension, shared by route modules that register
# websocket endpoints via `sock.route(...)`.
sock = Sock()


def init_socketio(app):
    """Bind the shared flask-sock ``Sock`` extension to the given Flask app.

    NOTE(review): despite the name, this uses flask-sock, not Flask-SocketIO.
    The name is kept because other modules import ``init_socketio``.
    """
    # No `global` declaration needed: we only call a method on `sock`,
    # we never rebind the module-level name.
    sock.init_app(app)

View File

@ -7,4 +7,5 @@ tiktoken
gunicorn gunicorn
redis redis
gevent gevent
async-timeout async-timeout
flask-sock

View File

@ -17,6 +17,7 @@ from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
from llm_server.routes.v1 import bp from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread from llm_server.threads import MainBackgroundThread
script_path = os.path.dirname(os.path.realpath(__file__)) script_path = os.path.dirname(os.path.realpath(__file__))
@ -88,6 +89,7 @@ SemaphoreCheckerThread().start()
app = Flask(__name__) app = Flask(__name__)
cache.init_app(app) cache.init_app(app)
cache.clear() # clear redis cache cache.clear() # clear redis cache
init_socketio(app)
# with app.app_context(): # with app.app_context():
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base") # current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/') app.register_blueprint(bp, url_prefix='/api/v1/')
@ -100,8 +102,8 @@ app.register_blueprint(bp, url_prefix='/api/v1/')
@app.route('/api') @app.route('/api')
@cache.cached(timeout=10, query_string=True) @cache.cached(timeout=10, query_string=True)
def home(): def home():
if not opts.full_client_api: if not opts.base_client_api:
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}' opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
stats = generate_stats() stats = generate_stats()
if not bool(redis.get('backend_online')) or not stats['online']: if not bool(redis.get('backend_online')) or not stats['online']:
@ -131,10 +133,12 @@ def home():
analytics_tracking_code=analytics_tracking_code, analytics_tracking_code=analytics_tracking_code,
info_html=info_html, info_html=info_html,
current_model=running_model, current_model=running_model,
client_api=opts.full_client_api, client_api=f'https://{opts.base_client_api}',
ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
estimated_wait=estimated_wait_sec, estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0], mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1], api_input_textbox=mode_ui_names[opts.mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size, context_size=opts.context_size,
stats_json=json.dumps(stats, indent=4, ensure_ascii=False), stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '', extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',

View File

@ -71,6 +71,7 @@
<div class="info-box"> <div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p> <p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
<p><strong>Client API URL:</strong> {{ client_api }}</p> <p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p> <p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
{{ info_html|safe }} {{ info_html|safe }}
</div> </div>
@ -83,6 +84,10 @@
<ol> <ol>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li> <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li> <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
<li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
</li>
<li>Click <kbd>Connect</kbd> to test the connection.</li> <li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li> <li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a> <li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>