finish openai endpoints

This commit is contained in:
Cyberes 2023-10-01 16:04:53 -06:00
parent 2a3ff7e21e
commit f7e9687527
15 changed files with 311 additions and 242 deletions

View File

@ -43,12 +43,10 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
Flask may give unusual errors when running `python server.py`. I think this is coming from flask-sock. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
### To Do
- [x] Implement streaming

View File

@ -14,8 +14,11 @@ def test_backend(backend_url: str, test_prompt: bool = False):
"temperature": 0,
"max_new_tokens": 3,
}
success, response, err = generator(data, backend_url, timeout=10)
if not success or not response or err:
try:
success, response, err = generator(data, backend_url, timeout=10)
if not success or not response or err:
return False, {}
except Exception:
return False, {}
i = get_info(backend_url, backend_info['mode'])
if not i.get('model'):

View File

@ -1,12 +1,14 @@
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
def generator(request_json_body, cluster_backend, timeout: int = None):
if opts.mode == 'oobabooga':
mode = cluster_config.get_backend(cluster_backend)['mode']
if mode == 'ooba':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
elif mode == 'vllm':
from .vllm.generate import generate
return generate(request_json_body, cluster_backend, timeout=timeout)
else:

View File

@ -12,6 +12,7 @@ class LLMBackend:
def __init__(self, backend_url: str):
self.backend_url = backend_url
self.backend_info = cluster_config.get_backend(self.backend_url)
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
@ -44,8 +45,7 @@ class LLMBackend:
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt, self.backend_url)
token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
token_limit = self.backend_info['model_config']['max_position_embeddings']
if prompt_len > token_limit - 10:
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
return True, None
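Below, a minimal sketch of how a route might call validate_prompt() before queueing a request. The VLLMBackend subclass name is an assumption (it isn't shown in this diff); the error message and the constructor signature come from the lines above.

# Hedged sketch: reject an oversized prompt before it ever reaches the queue.
# VLLMBackend is an assumed LLMBackend subclass; the base constructor above only needs backend_url.
backend = VLLMBackend(backend_url)
ok, err = backend.validate_prompt(prompt)   # compares the token count against max_position_embeddings - 10
if not ok:
    return jsonify({'error': err}), 400     # err carries the "token indices sequence length" message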

View File

@ -20,19 +20,17 @@ def generate_oai_string(length=24):
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg["content"]))
def get_token_count_thread(msg):
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
# If total tokens exceed the limit, start trimming
if total_tokens > context_token_limit:
if total_tokens + formatting_tokens > context_token_limit:
while True:
while total_tokens + formatting_tokens > context_token_limit:
# Calculate the index to start removing messages from
@ -45,15 +43,11 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
@ -65,11 +59,7 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg))
token_count = get_token_count_tiktoken_thread(prompt)
token_count = get_token_count(prompt, backend_url)
# If total tokens exceed the limit, start trimming
if token_count > context_token_limit:
@ -80,21 +70,17 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
while remove_index < len(prompt):
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
token_count = get_token_count_tiktoken_thread(prompt)
token_count = len(tokenizer.encode(prompt))
if token_count <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg, backend_url)
token_count = get_token_count_thread(prompt)
token_count = get_token_count(prompt, backend_url)
if token_count > context_token_limit:
# Start over, but this time calculate the token count using the backend
token_count = get_token_count_thread(prompt)
token_count = get_token_count(prompt, backend_url)
else:
break
print(token_count)
return prompt
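As a usage note, the chat route later in this commit drives these helpers as in the sketch below; the token budget comes straight from the backend's model_config, and the formatting overhead is handled inside the trimmer.

# Hedged usage sketch, mirroring the call made in the chat-completions route below.
limit = handler.cluster_backend_info['model_config']['max_position_embeddings']
trimmed = trim_messages_to_fit(request_json_body['messages'], limit, handler.backend_url)
handler.prompt = transform_messages_to_prompt(trimmed)   # the trimmer already accounted for formatting tokens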

View File

@ -1,29 +1,35 @@
import requests
import asyncio
import aiohttp
import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
tokenizer = tiktoken.get_encoding("cl100k_base")
token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
# First we tokenize it locally to determine if it's worth sending it to the backend.
initial_estimate = len(tokenizer.encode(prompt))
if initial_estimate <= token_limit + 200:
try:
r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
return len(tokenizer.encode(prompt)) + 10
else:
# If the estimate is already larger than our context size, just return it.
# We won't be sending this prompt to the backend, so the count doesn't need to be accurate.
return initial_estimate
async def run():
tokenizer = tiktoken.get_encoding("cl100k_base")
async def send_chunk(chunk):
try:
async with session.post(f'{backend_url}/tokenize', json={'input': chunk}, verify_ssl=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) as response:
j = await response.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
return len(tokenizer.encode(chunk)) + 10
chunk_size = 300
chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
async with aiohttp.ClientSession() as session:
tasks = [send_chunk(chunk) for chunk in chunks]
lengths = await asyncio.gather(*tasks)
return sum(lengths)
return asyncio.run(run())
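The replacement above is a two-stage count: a local tiktoken estimate first, then a call to the vLLM /tokenize route only when the prompt could plausibly fit, falling back to the estimate plus a 10-token pad on errors. A small sketch of how a caller might lean on that behaviour (prompt_fits itself is hypothetical; the -10 margin mirrors validate_prompt above):

def prompt_fits(prompt: str, backend_url: str, token_limit: int) -> bool:
    # tokenize() estimates locally and only calls the backend's /tokenize route when the
    # prompt could plausibly fit; it degrades to the padded local estimate on errors.
    return tokenize(prompt, backend_url) <= token_limit - 10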

View File

@ -13,7 +13,7 @@ class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def handle_request(self):
def handle_request(self, return_ok: bool = True):
assert not self.used
request_valid, invalid_response = self.validate_request()
@ -25,14 +25,19 @@ class OobaRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': prompt}
_, backend_response = self.generate_response(llm_request)
return backend_response
if return_ok:
# Always return 200 so ST displays our error messages
return backend_response[0], 200
else:
# The OpenAI route needs to detect 429 errors.
return backend_response
def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

View File

@ -8,10 +8,11 @@ from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -24,11 +25,6 @@ def openai_chat_completions():
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if not request_json_body.get('stream'):
try:
return handler.handle_request()
@ -37,30 +33,51 @@ def openai_chat_completions():
return 'Internal server error', 500
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
return 'DISABLED', 401
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
if opts.openai_silent_trim:
oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
oai_messages = handler.request.json['messages']
handler.prompt = transform_messages_to_prompt(oai_messages)
handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
if not event:
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
response_status_code,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
try:
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
@ -69,57 +86,61 @@ def openai_chat_completions():
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# The chunk didn't contain the accumulated prompt + generated text, so nothing new could be extracted; skip it.
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
finally:
# The worker incremented it; we decrement it here.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500
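For context, the stream above emits Server-Sent Events framed as "data: {json}\n\n" with the incremental text in choices[0].delta.content and a final "data: [DONE]" sentinel. A hedged client-side sketch follows; the URL path and model name are assumptions, since the route decorator isn't shown in this diff.

import json
import requests

resp = requests.post(
    'http://localhost:5000/api/openai/v1/chat/completions',   # assumed path; adjust to your deployment
    json={'model': 'gpt-3.5-turbo', 'stream': True,
          'messages': [{'role': 'user', 'content': 'Hello!'}]},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '):
        continue
    payload = line[len('data: '):]
    if payload == '[DONE]':
        break
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)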

View File

@ -8,6 +8,7 @@ from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm import get_token_count
@ -24,80 +25,98 @@ def openai_completions():
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
handler = OobaRequestHandler(incoming_request=request)
handler = OobaRequestHandler(incoming_request=request)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
# Convert parameters to the selected backend type
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
if not request_json_body.get('stream'):
response, status_code = handler.handle_request()
if status_code != 200:
return status_code
output = response.json['results'][0]['text']
if not request_json_body.get('stream'):
response, status_code = handler.handle_request(return_ok=False)
if status_code == 429:
return handler.handle_ratelimited()
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}
})
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
})
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response, 200
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response, 200
else:
if not opts.enable_streaming:
return 'DISABLED', 401
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
handler.prompt = handler.request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
response_status_code = 0
start_time = time.time()
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
if not event:
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
response_status_code,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = handler.request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
try:
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
@ -105,57 +124,61 @@ def openai_completions():
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# The chunk didn't contain the accumulated prompt + generated text, so nothing new could be extracted; skip it.
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
data = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
log_prompt(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
finally:
# The worker incremented it; we decrement it here.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'Internal Server Error', 500
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500

View File

@ -10,8 +10,9 @@ from flask import Response, jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.cluster.model_choices import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.database import is_api_key_moderated, log_prompt
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
@ -70,9 +71,24 @@ class OpenAIRequestHandler(RequestHandler):
return backend_response, backend_response_status_code
def handle_ratelimited(self, do_log: bool = True):
# TODO: return a simulated OpenAI error message
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
return 'Ratelimited', 429
_, default_backend_info = get_model_choices()
w = int(default_backend_info['estimated_wait']) if default_backend_info['estimated_wait'] > 0 else 2
response = jsonify({
"error": {
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
"type": "rate_limit_exceeded",
"param": None,
"code": None
}
})
response.headers['x-ratelimit-limit-requests'] = '2'
response.headers['x-ratelimit-remaining-requests'] = '0'
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
return jsonify({

View File

@ -209,7 +209,7 @@ class RequestHandler:
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False
else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
return True
def handle_request(self) -> Tuple[flask.Response, int]:

View File

@ -115,6 +115,10 @@ def do_stream(ws, model_name):
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
try:
response = generator(llm_request, handler.backend_url)
if not response:

View File

@ -6,6 +6,7 @@ from llm_server.custom_redis import flask_cache
from . import bp
from ... import opts
from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config
@bp.route('/v1/model', methods=['GET'])

View File

@ -21,8 +21,10 @@ def worker():
incr_active_workers(selected_model, backend_url)
if not request_json_body:
# This was a dummy request from the websocket handler.
# This was a dummy request from the websocket handlers.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
event = DataEvent(event_id)
event.set((True, None, None))
continue
try:
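This is the worker-side half of the streaming handshake added above: the OpenAI routes enqueue a placeholder request, wait for a worker to claim it, then stream from the backend themselves and release the counters in their finally blocks. A standard-library toy of the pattern, purely illustrative; the real code uses priority_queue, DataEvent, and Redis-backed counters rather than queue.Queue and threading.Event:

import queue
import threading

work_queue: queue.Queue = queue.Queue()

def worker():
    while True:
        request_json_body, event = work_queue.get()
        if request_json_body is None:
            # Dummy request from a streaming route: signal the route and move on.
            event.set()
            continue
        # ... a normal, non-streaming request would be generated here ...

threading.Thread(target=worker, daemon=True).start()

# Route side: enqueue the placeholder, wait for a worker to claim it, then stream directly.
done = threading.Event()
work_queue.put((None, done))
done.wait()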

View File

@ -13,4 +13,6 @@ openai~=0.28.0
urllib3~=2.0.4
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
redis==5.0.1
aiohttp==3.8.5
asyncio==3.4.3