implement streaming for hf-textgen

This commit is contained in:
Cyberes 2023-08-29 17:56:12 -06:00
parent 26b04f364c
commit bf648f605f
11 changed files with 116 additions and 14 deletions

View File

@ -52,3 +52,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
- Convince Oobabooga to implement concurrent generation
- Make sure stats work when starting from an empty database
- Make sure we're correctly canceling requests when the client cancels
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and the remove it before proxying??

View File

@ -19,8 +19,8 @@ config_default_vars = {
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url'),
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
}

View File

@ -50,7 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if not is_error:
if not response_tokens:
response_tokens = len(tokenizer.encode(response))
response_tokens = len(tokenizer.encode(response, disallowed_special=()))
else:
response_tokens = None

View File

@ -11,7 +11,7 @@ database_path = './proxy-server.db'
auth_required = False
log_prompts = False
frontend_api_client = ''
full_client_api = None
base_client_api = None
http_host = None
verify_ssl = True
show_num_prompts = True

View File

@ -6,18 +6,16 @@ from ... import opts
bp = Blueprint('v1', __name__)
# openai_bp = Blueprint('/v1', __name__)
@bp.before_request
def before_request():
if not opts.http_host:
opts.http_host = request.headers.get("Host")
if not opts.full_client_api:
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
if not opts.base_client_api:
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
if request.endpoint != 'v1.get_stats':
response = require_api_key()
if response is not None:
return response
from . import generate, info, proxy
from . import generate, info, proxy, generate_stream

View File

@ -81,7 +81,8 @@ def generate_stats():
},
'online': online,
'endpoints': {
'blocking': opts.full_client_api,
'blocking': f'https://{opts.base_client_api}',
'streaming': f'wss://{opts.base_client_api}/v1/stream',
},
'queue': {
'processing': active_gen_workers,

View File

@ -0,0 +1,84 @@
import json
import time
import requests
from flask import request
from ..helpers.client import format_sillytavern_err
from ... import opts
from ...database import log_prompt
from ...helpers import indefinite_article
from ...llm.hf_textgen.generate import prepare_json
from ...stream import sock
@sock.route('/api/v1/stream') # TODO: use blueprint route???
def stream(ws):
start_time = time.time()
if request.headers.get('cf-connecting-ip'):
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for').split(',')[0]
else:
client_ip = request.remote_addr
token = request.headers.get('X-Api-Key')
message_num = 0
while ws.connected:
message = ws.receive()
data = json.loads(message)
if opts.mode == 'hf-textgen':
response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=False)
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
details = {}
generated_text = ''
# Iterate over each line in the response
for line in response.iter_lines():
# Decode the line to a string
line = line.decode('utf-8')
# If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
if line.startswith('data:'):
line = line[5:]
json_data = json.loads(line)
details = json_data.get('details', {})
generated_text = json_data.get('generated_text', '')
if json_data.get('error'):
error_type = json_data.get('error_type')
error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
generated_text = format_sillytavern_err(
f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
f'HTTP CODE {response_status_code}')
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
}))
break
else:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': json_data['token']['text']
}))
message_num += 1
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
end_time = time.time()
elapsed_time = end_time - start_time
parameters = data.copy()
del parameters['prompt']
log_prompt(client_ip, token, data['prompt'], generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details['generated_tokens'])

8
llm_server/stream.py Normal file
View File

@ -0,0 +1,8 @@
from flask_sock import Sock
sock = Sock()
def init_socketio(app):
global sock
sock.init_app(app)

View File

@ -8,3 +8,4 @@ gunicorn
redis
gevent
async-timeout
flask-socketio

View File

@ -17,6 +17,7 @@ from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread
script_path = os.path.dirname(os.path.realpath(__file__))
@ -88,6 +89,7 @@ SemaphoreCheckerThread().start()
app = Flask(__name__)
cache.init_app(app)
cache.clear() # clear redis cache
init_socketio(app)
# with app.app_context():
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
app.register_blueprint(bp, url_prefix='/api/v1/')
@ -100,8 +102,8 @@ app.register_blueprint(bp, url_prefix='/api/v1/')
@app.route('/api')
@cache.cached(timeout=10, query_string=True)
def home():
if not opts.full_client_api:
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
if not opts.base_client_api:
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
stats = generate_stats()
if not bool(redis.get('backend_online')) or not stats['online']:
@ -131,10 +133,12 @@ def home():
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=running_model,
client_api=opts.full_client_api,
client_api=f'https://{opts.base_client_api}',
ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',

View File

@ -71,6 +71,7 @@
<div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
<p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
{{ info_html|safe }}
</div>
@ -83,6 +84,10 @@
<ol>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
<li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
</li>
<li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>