fix processing not being decremented on streaming, fix confusion over queue, adjust stop sequences

This commit is contained in:
Cyberes 2023-10-02 20:53:08 -06:00
parent 4f226ae38e
commit 94141b8ecf
16 changed files with 226 additions and 225 deletions

View File

@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
# TODO: give this a better name!
@ -30,7 +30,7 @@ def get_model_choices(regen: bool = False):
if backend_info.get('average_generation_elapsed_sec'):
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
active_gen_workers = get_active_gen_workers(model)
active_gen_workers = get_active_gen_workers_model(model)
proompters_in_queue = priority_queue.len(model)
if len(avg_gen_per_worker):

View File

@ -2,15 +2,15 @@ import json
import time
import traceback
from threading import Thread
from typing import Union
import llm_server
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.llm import get_token_count
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False):
def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
def background_task():
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
# Try not to shove JSON into the database.
@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url)
prompt_tokens = get_token_count(prompt, backend_url)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response, backend_url)
response_tokens = get_token_count(response, backend_url)
else:
response_tokens = None

View File

@ -3,6 +3,9 @@ from llm_server.custom_redis import redis
def get_token_count(prompt: str, backend_url: str):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
backend_mode = redis.get('backend_mode', dtype=str)
if backend_mode == 'vllm':
return vllm.tokenize(prompt, backend_url)

View File

@ -8,11 +8,11 @@ def oai_to_vllm(request_json_body, hashes: bool, mode):
request_json_body['stop'] = []
if hashes:
request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
if opts.openai_force_no_hashes:
request_json_body['stop'].append('### ')
else:
request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
request_json_body['stop'].extend(['user:', 'assistant:'])
if request_json_body.get('frequency_penalty', 0) < -2:
request_json_body['frequency_penalty'] = -2

View File

@ -8,6 +8,9 @@ from llm_server import opts
def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0

View File

@ -62,7 +62,7 @@ def openai_chat_completions():
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_prompt(
handler.client_ip,

View File

@ -100,7 +100,7 @@ def openai_completions():
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_prompt(
handler.client_ip,

View File

@ -8,11 +8,11 @@ from uuid import uuid4
import flask
from flask import Response, jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.cluster.model_choices import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated, log_prompt
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
@ -110,9 +110,8 @@ class OpenAIRequestHandler(RequestHandler):
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
prompt_tokens = get_token_count(prompt, self.backend_url)
response_tokens = get_token_count(response, self.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({

View File

@ -27,7 +27,6 @@ class RedisPriorityQueue:
def put(self, item, priority, selected_model):
event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count:
@ -99,16 +98,20 @@ class DataEvent:
priority_queue = RedisPriorityQueue()
def update_active_workers(key: str, operation: str):
    """Increment or decrement the ``active_gen_workers:<key>`` counter in Redis.

    :param key: suffix of the counter key; callers in this module pass either
        a model name or a backend URL.
    :param operation: ``'incr'`` to add one, ``'decr'`` to subtract one.
    :raises ValueError: if ``operation`` is neither ``'incr'`` nor ``'decr'``.

    After a decrement the counter is read back and reset to 0 if it went
    negative, so an unmatched decrement cannot leave a negative count.
    """
    # Validate before touching Redis so a mistyped operation fails loudly
    # instead of silently leaving the counter unchanged.
    if operation not in ('incr', 'decr'):
        raise ValueError(f"operation must be 'incr' or 'decr', got {operation!r}")

    counter_key = f'active_gen_workers:{key}'
    if operation == 'incr':
        redis.incr(counter_key)
        return

    redis.decr(counter_key)
    # NOTE(review): the decrement and the clamp are separate Redis calls, so
    # another worker may observe a briefly negative value in between.
    if redis.get(counter_key, default=0, dtype=int) < 0:
        redis.set(counter_key, 0)
def incr_active_workers(selected_model: str, backend_url: str):
redis.incr(f'active_gen_workers:{selected_model}')
redis.incr(f'active_gen_workers:{backend_url}')
update_active_workers(selected_model, 'incr')
update_active_workers(backend_url, 'incr')
def decr_active_workers(selected_model: str, backend_url: str):
redis.decr(f'active_gen_workers:{selected_model}')
if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0:
redis.set(f'active_gen_workers:{selected_model}', 0)
redis.decr(f'active_gen_workers:{backend_url}')
if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0:
redis.set(f'active_gen_workers:{backend_url}', 0)
update_active_workers(selected_model, 'decr')
update_active_workers(backend_url, 'decr')

View File

@ -36,6 +36,7 @@ class RequestHandler:
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
self.selected_model = self.cluster_backend_info['model']
if not self.cluster_backend_info.get('mode'):
print(selected_model, self.backend_url, self.cluster_backend_info)
@ -43,7 +44,6 @@ class RequestHandler:
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
self.parameters = None
self.used = False
self.selected_model = selected_model
redis.zadd('recent_prompters', {self.client_ip: time.time()})
def get_auth_token(self):

View File

@ -15,13 +15,8 @@ def get_total_proompts():
return count
def get_active_gen_workers(selected_model: str = None, ):
active_gen_workers = redis.get(f'active_gen_workers:{selected_model}')
if active_gen_workers is None:
count = 0
else:
count = int(active_gen_workers)
return count
def get_active_gen_workers_model(selected_model: str = None):
    """Look up the active generation worker counter for a model.

    Reads the ``active_gen_workers:<selected_model>`` key from Redis,
    requesting an int and passing 0 as the default value.
    """
    counter_key = f'active_gen_workers:{selected_model}'
    return redis.get(counter_key, dtype=int, default=0)
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):

View File

@ -11,7 +11,6 @@ from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...sock import sock
@ -45,7 +44,6 @@ def do_stream(ws, model_name):
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
@ -56,7 +54,7 @@ def do_stream(ws, model_name):
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.cluster_backend_info,
response_tokens=tokenize(generated_text, handler.backend_url),
response_tokens=None,
is_error=True
)
@ -67,195 +65,192 @@ def do_stream(ws, model_name):
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
try:
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
ws.close()
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body)
if auth_failure:
ws.close()
return auth_failure
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
}
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
else:
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request, handler.backend_url)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
try:
response = generator(llm_request, handler.backend_url)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': msg
}))
else:
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The client has closed the stream.
if request:
# Cancel the backend?
request.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None
)
return
message_num += 1
partial_response = b'' # Reset the partial response
# If there is no more data, break the loop
if not chunk:
break
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None,
is_error=not response
)
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': msg
'text': generated_text
}))
else:
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The client has closed the stream.
if request:
request.close()
try:
ws.close()
except:
pass
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url)
)
return
message_num += 1
partial_response = b'' # Reset the partial response
# If there is no more data, break the loop
if not chunk:
break
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url),
is_error=not response
)
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
}))
if request:
request.close()
ws.close()
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url),
is_error=True
)
return
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=tokenize(generated_text, handler.backend_url)
)
try:
ws.close() # this is important if we encountered an error and exited early.
except:
pass
if request:
request.close()
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None,
is_error=True
)
return
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None
)
finally:
try:
# Must close the connection or greenlets will complain.
ws.close()
except:
pass

View File

@ -11,19 +11,23 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip
def worker():
while True:
(request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
if not backend_url:
backend_url = get_a_cluster_backend(selected_model)
backend_info = cluster_config.get_backend(backend_url)
# The backend could have died between when the request was
# submitted and now, so let's double check it's still online.
if not backend_info['online']:
old = backend_url
backend_url = get_a_cluster_backend()
backend_info = cluster_config.get_backend(backend_url)
print(f'Backend {old} offline. Request was redirected to {backend_url}')
del old
del old # gc
if not selected_model:
selected_model = backend_info['model']
# This wait time is "invisible", meaning the worker may as
# This wait time will be "invisible", meaning the worker may as
# well be still waiting to get an item from the queue.
need_to_wait(backend_url)
@ -32,7 +36,8 @@ def worker():
if not request_json_body:
# This was a dummy request from the websocket handlers.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
# We're going to let the websocket handler decrement
# processing_ips and active_gen_workers.
event = DataEvent(event_id)
event.set((True, None, None))
continue

0
other/vllm/vllm_api_server.py Normal file → Executable file
View File

View File

@ -13,6 +13,4 @@ openai~=0.28.0
urllib3~=2.0.4
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
aiohttp==3.8.5
asyncio==3.4.3
redis==5.0.1

View File

@ -24,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.sock import init_socketio
# TODO: implement blind RRD controlled via header and only used when there is a queue on the primary backend(s)
# TODO: is frequency penalty the same as ooba repetition penalty???
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# TODO: if a backend is at its limit of concurrent requests, choose a different one
@ -93,7 +94,6 @@ create_db()
def home():
base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats()
model_choices, default_backend_info = get_model_choices()
if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens: