implement streaming for vllm

parent 81452ec643
commit 76a1428ba0
@@ -91,6 +91,6 @@ def handle_blocking_request(json_data: dict):
 
 def generate(json_data: dict):
     if json_data.get('stream'):
-        raise Exception('streaming not implemented')
+        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
     else:
         return handle_blocking_request(json_data)
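For reference, callers pick between the two code paths with the 'stream' key in the request body; a quick sketch (prompt value illustrative):

    # Streaming: returns the live requests.Response to iterate over.
    streaming_response = generate({'prompt': 'Once upon a time', 'stream': True})

    # Blocking: delegated to handle_blocking_request() and returned whole.
    blocking_response = generate({'prompt': 'Once upon a time'})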
@@ -39,7 +39,9 @@ def require_api_key():
     return jsonify({'code': 401, 'message': 'API key required'}), 401
 
 
-def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response]):
+def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
+    if isinstance(data, dict):
+        return True, data
     try:
         if isinstance(data, (Request, flask.Response)):
             data = data.json
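With the added dict short-circuit, validate_json accepts an already-parsed body as well as the existing string and request/response inputs, always returning a (valid, data) tuple. A usage sketch (values illustrative; the string path is assumed to parse JSON, per the Union signature):

    ok, data = validate_json({'prompt': 'hi'})    # dict passes straight through: (True, {'prompt': 'hi'})
    ok, data = validate_json('{"prompt": "hi"}')  # raw JSON string, parsed inside
    ok, data = validate_json(flask.request)       # Flask request body, loaded via .json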
@@ -18,9 +18,17 @@ DEFAULT_PRIORITY = 9999
 
 
 class RequestHandler:
-    def __init__(self, incoming_request: flask.Request):
+    def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
-        _, self.request_json_body = validate_json(self.request)  # routes need to validate it, here we just load it
+
+        # routes need to validate it, here we just load it
+        if incoming_json:
+            self.request_valid_json, self.request_json_body = validate_json(incoming_json)
+        else:
+            self.request_valid_json, self.request_json_body = validate_json(self.request)
+        if not self.request_valid_json:
+            raise Exception(f'Not valid JSON')
+
         self.start_time = time.time()
         self.client_ip = self.get_client_ip()
         self.token = self.request.headers.get('X-Api-Key')
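The new incoming_json parameter lets a route hand the handler a body it has already parsed (e.g. a websocket message) instead of re-reading flask.Request; either way the constructor now validates the JSON itself and raises if it is bad. A sketch (assuming OobaRequestHandler extends RequestHandler, as its use in the streaming route below suggests):

    handler = OobaRequestHandler(request)                     # HTTP route: body loaded from the Flask request
    handler = OobaRequestHandler(request, request_json_body)  # websocket route: body already parsed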
@@ -1,13 +1,14 @@
 import json
 import time
 
 import requests
 from flask import request
 
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
+from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...database.database import log_prompt
 from ...helpers import indefinite_article
 from ...llm.generator import generator
+from ...llm.vllm import tokenize
 from ...stream import sock
@@ -19,71 +20,97 @@ def stream(ws):
         # TODO: return a formatted ST error message
         return 'disabled', 401
 
-    # start_time = time.time()
-    # if request.headers.get('cf-connecting-ip'):
-    #     client_ip = request.headers.get('cf-connecting-ip')
-    # elif request.headers.get('x-forwarded-for'):
-    #     client_ip = request.headers.get('x-forwarded-for').split(',')[0]
-    # else:
-    #     client_ip = request.remote_addr
-    # token = request.headers.get('X-Api-Key')
-    #
-    # message_num = 0
-    # while ws.connected:
-    #     message = ws.receive()
-    #     data = json.loads(message)
-    #
-    #     if opts.mode == 'hf-textgen':
-    #         response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=False)
-    #
-    #         # Be extra careful when getting attributes from the response object
-    #         try:
-    #             response_status_code = response.status_code
-    #         except:
-    #             response_status_code = 0
-    #
-    #         details = {}
-    #         generated_text = ''
-    #
-    #         # Iterate over each line in the response
-    #         for line in response.iter_lines():
-    #             # Decode the line to a string
-    #             line = line.decode('utf-8')
-    #             # If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
-    #             if line.startswith('data:'):
-    #                 line = line[5:]
-    #                 json_data = json.loads(line)
-    #                 details = json_data.get('details', {})
-    #                 generated_text = json_data.get('generated_text', '')
-    #
-    #                 if json_data.get('error'):
-    #                     error_type = json_data.get('error_type')
-    #                     error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
-    #                     generated_text = format_sillytavern_err(
-    #                         f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
-    #                         f'HTTP CODE {response_status_code}')
-    #                     ws.send(json.dumps({
-    #                         'event': 'text_stream',
-    #                         'message_num': message_num,
-    #                         'text': generated_text
-    #                     }))
-    #                     break
-    #                 else:
-    #                     ws.send(json.dumps({
-    #                         'event': 'text_stream',
-    #                         'message_num': message_num,
-    #                         'text': json_data['token']['text']
-    #                     }))
-    #                     message_num += 1
-    #
-    #         ws.send(json.dumps({
-    #             'event': 'stream_end',
-    #             'message_num': message_num
-    #         }))
-    #
-    #         end_time = time.time()
-    #         elapsed_time = end_time - start_time
-    #         parameters = data.copy()
-    #         del parameters['prompt']
-    #
-    #         log_prompt(client_ip, token, data['prompt'], generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details['generated_tokens'])
+    message_num = 0
+    while ws.connected:
+        message = ws.receive()
+        request_valid_json, request_json_body = validate_json(message)
+        if not request_valid_json or not request_json_body.get('prompt'):
+            ws.send(json.dumps({
+                'event': 'text_stream',
+                'message_num': message_num,
+                'text': 'Invalid JSON'
+            }))
+            message_num += 1
+        else:
+            if opts.mode != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
+            handler = OobaRequestHandler(request, request_json_body)
+            generated_text = ''
+            input_prompt = None
+            response_status_code = 0
+            start_time = time.time()
+            request_valid, invalid_response = handler.validate_request()
+            if not request_valid:
+                ws.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'text': invalid_response
+                }))
+            else:
+                input_prompt = request_json_body['prompt']
+                msg_to_backend = {
+                    **handler.parameters,
+                    'prompt': input_prompt,
+                    'stream': True,
+                }
+                response = generator(msg_to_backend)
+
+                # Be extra careful when getting attributes from the response object
+                try:
+                    response_status_code = response.status_code
+                except:
+                    response_status_code = 0
+
+                # details = {}
+
+                # Initialize an empty byte string to store parts of the response
+                partial_response = b''
+
+                # Process each part of the response as it's received
+                for chunk in response.iter_content(chunk_size=1):
+                    # Add the chunk to the partial response
+                    partial_response += chunk
+
+                    # If the partial response ends with a null character, parse it as JSON
+                    if partial_response.endswith(b'\x00'):
+                        # Remove the null character and decode the byte string to a string
+                        json_str = partial_response[:-1].decode()
+
+                        # Parse the string as JSON
+                        json_obj = json.loads(json_str)
+
+                        # Strip the input prompt from the response
+                        if generated_text:
+                            new = json_obj['text'][0].split(generated_text)[1]
+                        else:
+                            new = json_obj['text'][0].split(input_prompt)[1]
+
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': new
+                        }))
+                        message_num += 1
+
+                        generated_text = json_obj['text'][0]
+
+                        # Reset the partial response
+                        partial_response = b''
+
+                    # If there is no more data, break the loop
+                    if not chunk:
+                        break
+
+                response.close()
+
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                generated_tokens = tokenize(generated_text)
+                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+        ws.send(json.dumps({
+            'event': 'stream_end',
+            'message_num': message_num
+        }))
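The chunk loop above implements vLLM's framing: the backend streams JSON objects, each terminated by a null byte and carrying the cumulative completion in json_obj['text'][0], which is why the route splits on the text it has already sent. The same parser as a standalone sketch (the helper name is ours, not part of the commit):

    import json

    def iter_null_delimited_json(response):
        # Yield each b'\x00'-terminated JSON frame from a streaming requests.Response.
        partial = b''
        for chunk in response.iter_content(chunk_size=1):
            if not chunk:
                break
            partial += chunk
            if partial.endswith(b'\x00'):
                yield json.loads(partial[:-1].decode())
                partial = b''

Reading one byte at a time keeps the delimiter check trivial but is slow; a larger chunk_size plus a split on b'\x00' would trade that simplicity for throughput.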
@@ -0,0 +1,86 @@
+import asyncio
+import json
+import sys
+
+try:
+    import websockets
+except ImportError:
+    print("Websockets package not found. Make sure it's installed.")
+
+# For local streaming, the websockets are hosted without ssl - ws://
+HOST = 'localhost:5000'
+URI = f'ws://{HOST}/api/v1/stream'
+
+# For reverse-proxied streaming, the remote will likely host with ssl - wss://
+# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
+
+
+async def run(context):
+    # Note: the selected defaults change from time to time.
+    request = {
+        'prompt': context,
+        'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
+
+        # Generation params. If 'preset' is set to different than 'None', the values
+        # in presets/preset-name.yaml are used instead of the individual numbers.
+        'preset': 'None',
+        'do_sample': True,
+        'temperature': 0.7,
+        'top_p': 0.1,
+        'typical_p': 1,
+        'epsilon_cutoff': 0,  # In units of 1e-4
+        'eta_cutoff': 0,  # In units of 1e-4
+        'tfs': 1,
+        'top_a': 0,
+        'repetition_penalty': 1.18,
+        'repetition_penalty_range': 0,
+        'top_k': 40,
+        'min_length': 0,
+        'no_repeat_ngram_size': 0,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'length_penalty': 1,
+        'early_stopping': False,
+        'mirostat_mode': 0,
+        'mirostat_tau': 5,
+        'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
+
+        'seed': -1,
+        'add_bos_token': True,
+        'truncation_length': 2048,
+        'ban_eos_token': False,
+        'custom_token_bans': '',
+        'skip_special_tokens': True,
+        'stopping_strings': []
+    }
+
+    async with websockets.connect(URI, ping_interval=None) as websocket:
+        await websocket.send(json.dumps(request))
+
+        yield context  # Remove this if you just want to see the reply
+
+        while True:
+            incoming_data = await websocket.recv()
+            incoming_data = json.loads(incoming_data)
+
+            match incoming_data['event']:
+                case 'text_stream':
+                    yield incoming_data['text']
+                case 'stream_end':
+                    return
+
+
+async def print_response_stream(prompt):
+    async for response in run(prompt):
+        print(response, end='')
+        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.
+    print('\n\nfinished')
+
+
+if __name__ == '__main__':
+    prompt = "In order to make homemade bread, follow these steps:\n1)"
+    asyncio.run(print_response_stream(prompt))
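A small variation on the client above: collect the whole reply instead of printing tokens as they arrive (a sketch reusing the run() generator from the new file):

    import asyncio

    async def collect_response(prompt):
        # run() yields the prompt first, then each streamed fragment.
        parts = []
        async for fragment in run(prompt):
            parts.append(fragment)
        return ''.join(parts)

    print(asyncio.run(collect_response('In order to make homemade bread, follow these steps:\n1)')))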