fix processing not being decremented on streaming, fix confusion over queue, adjust stop sequences
This commit is contained in:
parent
4f226ae38e
commit
94141b8ecf
|
@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_
|
|||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers
|
||||
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
|
||||
|
||||
|
||||
# TODO: give this a better name!
|
||||
|
@ -30,7 +30,7 @@ def get_model_choices(regen: bool = False):
|
|||
if backend_info.get('average_generation_elapsed_sec'):
|
||||
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
|
||||
|
||||
active_gen_workers = get_active_gen_workers(model)
|
||||
active_gen_workers = get_active_gen_workers_model(model)
|
||||
proompters_in_queue = priority_queue.len(model)
|
||||
|
||||
if len(avg_gen_per_worker):
|
||||
|
|
|
@ -2,15 +2,15 @@ import json
|
|||
import time
|
||||
import traceback
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.llm.vllm import tokenize
|
||||
from llm_server.llm import get_token_count
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False):
|
||||
def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
|
||||
def background_task():
|
||||
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
|
||||
# Try not to shove JSON into the database.
|
||||
|
@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
except:
|
||||
pass
|
||||
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url)
|
||||
prompt_tokens = get_token_count(prompt, backend_url)
|
||||
if not is_error:
|
||||
if not response_tokens:
|
||||
response_tokens = llm_server.llm.get_token_count(response, backend_url)
|
||||
response_tokens = get_token_count(response, backend_url)
|
||||
else:
|
||||
response_tokens = None
|
||||
|
||||
|
|
|
@ -3,6 +3,9 @@ from llm_server.custom_redis import redis
|
|||
|
||||
|
||||
def get_token_count(prompt: str, backend_url: str):
|
||||
assert isinstance(prompt, str)
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
backend_mode = redis.get('backend_mode', dtype=str)
|
||||
if backend_mode == 'vllm':
|
||||
return vllm.tokenize(prompt, backend_url)
|
||||
|
|
|
@ -8,11 +8,11 @@ def oai_to_vllm(request_json_body, hashes: bool, mode):
|
|||
request_json_body['stop'] = []
|
||||
|
||||
if hashes:
|
||||
request_json_body['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
|
||||
if opts.openai_force_no_hashes:
|
||||
request_json_body['stop'].append('### ')
|
||||
else:
|
||||
request_json_body['stop'].extend(['\nuser:', '\nassistant:'])
|
||||
request_json_body['stop'].extend(['user:', 'assistant:'])
|
||||
|
||||
if request_json_body.get('frequency_penalty', 0) < -2:
|
||||
request_json_body['frequency_penalty'] = -2
|
||||
|
|
|
@ -8,6 +8,9 @@ from llm_server import opts
|
|||
|
||||
def tokenize(prompt: str, backend_url: str) -> int:
|
||||
assert backend_url
|
||||
assert isinstance(prompt, str)
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
if not prompt:
|
||||
# The tokenizers have issues when the prompt is None.
|
||||
return 0
|
||||
|
|
|
@ -62,7 +62,7 @@ def openai_chat_completions():
|
|||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
|
|
|
@ -100,7 +100,7 @@ def openai_completions():
|
|||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
|
|
|
@ -8,11 +8,11 @@ from uuid import uuid4
|
|||
import flask
|
||||
from flask import Response, jsonify, make_response
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.model_choices import get_model_choices
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated, log_prompt
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
@ -110,9 +110,8 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
response = re.sub(ANTI_RESPONSE_RE, '', response)
|
||||
response = re.sub(ANTI_CONTINUATION_RE, '', response)
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt, self.backend_url)
|
||||
response_tokens = llm_server.llm.get_token_count(response, self.backend_url)
|
||||
prompt_tokens = get_token_count(prompt, self.backend_url)
|
||||
response_tokens = get_token_count(response, self.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
|
|
|
@ -27,7 +27,6 @@ class RedisPriorityQueue:
|
|||
|
||||
def put(self, item, priority, selected_model):
|
||||
event = DataEvent()
|
||||
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
ip_count = self.redis.hget('queued_ip_count', item[1])
|
||||
if ip_count:
|
||||
|
@ -99,16 +98,20 @@ class DataEvent:
|
|||
priority_queue = RedisPriorityQueue()
|
||||
|
||||
|
||||
def update_active_workers(key: str, operation: str):
    """Increment or decrement the ``active_gen_workers:<key>`` counter in Redis.

    :param key: Suffix of the counter name. Callers in this module pass either
        a model name or a backend URL.
    :param operation: ``'incr'`` to add one, ``'decr'`` to subtract one.
    :raises ValueError: if ``operation`` is neither ``'incr'`` nor ``'decr'``.
    """
    # Reject unknown operations up front instead of silently doing nothing;
    # a typo here would otherwise make the counters drift unnoticed.
    if operation not in ('incr', 'decr'):
        raise ValueError(f"operation must be 'incr' or 'decr', got {operation!r}")

    counter = f'active_gen_workers:{key}'

    if operation == 'incr':
        redis.incr(counter)
        return

    redis.decr(counter)
    # Never let the counter stay negative (e.g. after a decrement without a
    # matching increment).
    # NOTE(review): the decrement and the clamp are separate Redis calls, so
    # concurrent updates may interleave between them - confirm this is acceptable.
    if redis.get(counter, default=0, dtype=int) < 0:
        redis.set(counter, 0)
|
||||
|
||||
|
||||
def incr_active_workers(selected_model: str, backend_url: str):
|
||||
redis.incr(f'active_gen_workers:{selected_model}')
|
||||
redis.incr(f'active_gen_workers:{backend_url}')
|
||||
update_active_workers(selected_model, 'incr')
|
||||
update_active_workers(backend_url, 'incr')
|
||||
|
||||
|
||||
def decr_active_workers(selected_model: str, backend_url: str):
|
||||
redis.decr(f'active_gen_workers:{selected_model}')
|
||||
if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0:
|
||||
redis.set(f'active_gen_workers:{selected_model}', 0)
|
||||
|
||||
redis.decr(f'active_gen_workers:{backend_url}')
|
||||
if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0:
|
||||
redis.set(f'active_gen_workers:{backend_url}', 0)
|
||||
update_active_workers(selected_model, 'decr')
|
||||
update_active_workers(backend_url, 'decr')
|
||||
|
|
|
@ -36,6 +36,7 @@ class RequestHandler:
|
|||
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
self.selected_model = self.cluster_backend_info['model']
|
||||
|
||||
if not self.cluster_backend_info.get('mode'):
|
||||
print(selected_model, self.backend_url, self.cluster_backend_info)
|
||||
|
@ -43,7 +44,6 @@ class RequestHandler:
|
|||
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
self.selected_model = selected_model
|
||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||
|
||||
def get_auth_token(self):
|
||||
|
|
|
@ -15,13 +15,8 @@ def get_total_proompts():
|
|||
return count
|
||||
|
||||
|
||||
def get_active_gen_workers(selected_model: str = None, ):
    """Return the ``active_gen_workers:<selected_model>`` counter as an int.

    A missing value (``None`` from Redis) is reported as ``0``; any other
    value is converted with ``int()``.
    """
    raw_count = redis.get(f'active_gen_workers:{selected_model}')
    return 0 if raw_count is None else int(raw_count)
|
||||
def get_active_gen_workers_model(selected_model: str = None):
    """Return the ``active_gen_workers:<selected_model>`` counter.

    The value is requested from Redis as an ``int`` and falls back to ``0``
    when no value is stored.
    """
    counter_key = f'active_gen_workers:{selected_model}'
    active_workers = redis.get(counter_key, dtype=int, default=0)
    return active_workers
|
||||
|
||||
|
||||
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
|
||||
|
|
|
@ -11,7 +11,6 @@ from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
|||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.vllm import tokenize
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
|
@ -45,7 +44,6 @@ def do_stream(ws, model_name):
|
|||
'event': 'stream_end',
|
||||
'message_num': 1
|
||||
}))
|
||||
ws.close()
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
|
@ -56,7 +54,7 @@ def do_stream(ws, model_name):
|
|||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
|
||||
|
@ -67,195 +65,192 @@ def do_stream(ws, model_name):
|
|||
r_url = request.url
|
||||
message_num = 0
|
||||
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
try:
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
ws.close()
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
if opts.mode != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
ws.close()
|
||||
return auth_failure
|
||||
|
||||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
err_msg = None
|
||||
if handler.is_client_ratelimited():
|
||||
r, _ = handler.handle_ratelimited(do_log=False)
|
||||
err_msg = r.json['results'][0]['text']
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
|
||||
if not request_valid:
|
||||
err_msg = invalid_response[0].json['results'][0]['text']
|
||||
if err_msg:
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
if opts.mode != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
llm_request = {
|
||||
**handler.parameters,
|
||||
'prompt': input_prompt,
|
||||
'stream': True,
|
||||
}
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
if not event:
|
||||
r, _ = handler.handle_ratelimited()
|
||||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
# Wait for a worker to get our request and discard it.
|
||||
_, _, _ = event.wait()
|
||||
err_msg = None
|
||||
if handler.is_client_ratelimited():
|
||||
r, _ = handler.handle_ratelimited(do_log=False)
|
||||
err_msg = r.json['results'][0]['text']
|
||||
else:
|
||||
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
|
||||
if not request_valid:
|
||||
err_msg = invalid_response[0].json['results'][0]['text']
|
||||
if err_msg:
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
response = generator(llm_request, handler.backend_url)
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
|
||||
llm_request = {
|
||||
**handler.parameters,
|
||||
'prompt': input_prompt,
|
||||
'stream': True,
|
||||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
r, _ = handler.handle_ratelimited()
|
||||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
|
||||
# Wait for a worker to get our request and discard it.
|
||||
_, _, _ = event.wait()
|
||||
|
||||
try:
|
||||
response = generator(llm_request, handler.backend_url)
|
||||
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': msg
|
||||
}))
|
||||
else:
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
|
||||
partial_response = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
except:
|
||||
# The client has closed the stream.
|
||||
if request:
|
||||
# Cancel the backend?
|
||||
request.close()
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None
|
||||
)
|
||||
return
|
||||
|
||||
message_num += 1
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None,
|
||||
is_error=not response
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': msg
|
||||
'text': generated_text
|
||||
}))
|
||||
else:
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
|
||||
partial_response = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
except:
|
||||
# The client has closed the stream.
|
||||
if request:
|
||||
request.close()
|
||||
try:
|
||||
ws.close()
|
||||
except:
|
||||
pass
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
message_num += 1
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
is_error=not response
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': generated_text
|
||||
}))
|
||||
if request:
|
||||
request.close()
|
||||
ws.close()
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url),
|
||||
is_error=True
|
||||
)
|
||||
return
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
except:
|
||||
# The client closed the stream.
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=tokenize(generated_text, handler.backend_url)
|
||||
)
|
||||
try:
|
||||
ws.close() # this is important if we encountered an error and exited early.
|
||||
except:
|
||||
pass
|
||||
if request:
|
||||
request.close()
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
return
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
except:
|
||||
# The client closed the stream.
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
# Must close the connection or greenlets will complain.
|
||||
ws.close()
|
||||
except:
|
||||
pass
|
||||
|
|
|
@ -11,19 +11,23 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip
|
|||
def worker():
|
||||
while True:
|
||||
(request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
|
||||
if not backend_url:
|
||||
backend_url = get_a_cluster_backend(selected_model)
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
|
||||
# The backend could have died between when the request was
|
||||
# submitted and now, so let's double check it's still online.
|
||||
if not backend_info['online']:
|
||||
old = backend_url
|
||||
backend_url = get_a_cluster_backend()
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
print(f'Backend {old} offline. Request was redirected to {backend_url}')
|
||||
del old
|
||||
del old # gc
|
||||
|
||||
if not selected_model:
|
||||
selected_model = backend_info['model']
|
||||
|
||||
# This wait time is "invisible", meaning the worker may as
|
||||
# This wait time will be "invisible", meaning the worker may as
|
||||
# well be still waiting to get an item from the queue.
|
||||
need_to_wait(backend_url)
|
||||
|
||||
|
@ -32,7 +36,8 @@ def worker():
|
|||
|
||||
if not request_json_body:
|
||||
# This was a dummy request from the websocket handlers.
|
||||
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
|
||||
# We're going to let the websocket handler decrement
|
||||
# processing_ips and active_gen_workers.
|
||||
event = DataEvent(event_id)
|
||||
event.set((True, None, None))
|
||||
continue
|
||||
|
|
|
@ -14,5 +14,3 @@ urllib3~=2.0.4
|
|||
flask-sock==0.6.0
|
||||
gunicorn==21.2.0
|
||||
redis==5.0.1
|
||||
aiohttp==3.8.5
|
||||
asyncio==3.4.3
|
|
@ -24,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
|
|||
from llm_server.routes.v1 import bp
|
||||
from llm_server.sock import init_socketio
|
||||
|
||||
# TODO: implement blind RRD controlled via header and only used when there is a queue on the primary backend(s)
|
||||
# TODO: is frequency penalty the same as ooba repetition penalty???
|
||||
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
|
||||
# TODO: if a backend is at its limit of concurrent requests, choose a different one
|
||||
|
@ -93,7 +94,6 @@ create_db()
|
|||
def home():
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
stats = generate_stats()
|
||||
|
||||
model_choices, default_backend_info = get_model_choices()
|
||||
|
||||
if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens:
|
||||
|
|
Reference in New Issue