finish openai endpoints

Cyberes 2023-10-01 16:04:53 -06:00
parent 2a3ff7e21e
commit f7e9687527
15 changed files with 311 additions and 242 deletions

View File

@@ -43,12 +43,10 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 ### Use
 
-If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
+If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
 
 Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
 
 ### To Do
 
 - [x] Implement streaming

View File

@@ -14,8 +14,11 @@ def test_backend(backend_url: str, test_prompt: bool = False):
         "temperature": 0,
         "max_new_tokens": 3,
     }
-    success, response, err = generator(data, backend_url, timeout=10)
-    if not success or not response or err:
-        return False, {}
+    try:
+        success, response, err = generator(data, backend_url, timeout=10)
+        if not success or not response or err:
+            return False, {}
+    except:
+        return False, {}
     i = get_info(backend_url, backend_info['mode'])
     if not i.get('model'):

View File

@@ -1,12 +1,14 @@
 from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
 
 
 def generator(request_json_body, cluster_backend, timeout: int = None):
-    if opts.mode == 'oobabooga':
+    mode = cluster_config.get_backend(cluster_backend)['mode']
+    if mode == 'ooba':
         # from .oobabooga.generate import generate
         # return generate(request_json_body)
         raise NotImplementedError
-    elif opts.mode == 'vllm':
+    elif mode == 'vllm':
         from .vllm.generate import generate
         return generate(request_json_body, cluster_backend, timeout=timeout)
     else:

View File

@@ -12,6 +12,7 @@ class LLMBackend:
     def __init__(self, backend_url: str):
         self.backend_url = backend_url
+        self.backend_info = cluster_config.get_backend(self.backend_url)
 
     def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError
@@ -44,8 +45,7 @@ class LLMBackend:
     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
         prompt_len = get_token_count(prompt, self.backend_url)
-        token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
+        token_limit = self.backend_info['model_config']['max_position_embeddings']
         if prompt_len > token_limit - 10:
-            model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
         return True, None
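
For reference, the check above boils down to "reject anything that comes within 10 tokens of max_position_embeddings". A minimal standalone sketch of that logic (the `backend_info` dict shape and the function name below are illustrative, taken from this diff rather than from the project's actual API):

    # Sketch: reject prompts that would overflow the model's context window.
    # The 10-token margin mirrors the `token_limit - 10` headroom above.
    from typing import Optional, Tuple

    def check_prompt_fits(prompt_len: int, backend_info: dict) -> Tuple[bool, Optional[str]]:
        token_limit = backend_info['model_config']['max_position_embeddings']
        if prompt_len > token_limit - 10:
            return False, (f'Token indices sequence length is longer than the specified maximum '
                           f'sequence length for this model ({prompt_len} > {token_limit}, '
                           f'model: {backend_info["model"]}). Please lower your context size')
        return True, None

    # Example: a 4096-token model rejects a 4200-token prompt.
    info = {'model_config': {'max_position_embeddings': 4096}, 'model': 'example-13b'}
    print(check_prompt_fits(4200, info))  # (False, '...')
    print(check_prompt_fits(1000, info))  # (True, None)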

View File

@@ -20,19 +20,17 @@ def generate_oai_string(length=24):
 def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
-    tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def get_token_count_tiktoken_thread(msg):
-        return len(tokenizer.encode(msg["content"]))
+    def get_token_count_thread(msg):
+        return get_token_count(msg["content"], backend_url)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
+        token_counts = list(executor.map(get_token_count_thread, prompt))
 
     total_tokens = sum(token_counts)
-    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
+    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
     # If total tokens exceed the limit, start trimming
-    if total_tokens > context_token_limit:
+    if total_tokens + formatting_tokens > context_token_limit:
         while True:
             while total_tokens + formatting_tokens > context_token_limit:
                 # Calculate the index to start removing messages from
@@ -45,15 +43,11 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
                 if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                     break
 
-            def get_token_count_thread(msg):
-                return get_token_count(msg["content"], backend_url)
-
             with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                 token_counts = list(executor.map(get_token_count_thread, prompt))
 
             total_tokens = sum(token_counts)
             formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
             if total_tokens + formatting_tokens > context_token_limit:
                 # Start over, but this time calculate the token count using the backend
                 with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
@@ -65,11 +59,7 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
 def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
     tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def get_token_count_tiktoken_thread(msg):
-        return len(tokenizer.encode(msg))
-
-    token_count = get_token_count_tiktoken_thread(prompt)
+    token_count = get_token_count(prompt, backend_url)
 
     # If total tokens exceed the limit, start trimming
     if token_count > context_token_limit:
@@ -80,21 +70,17 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
             while remove_index < len(prompt):
                 prompt = prompt[:remove_index] + prompt[remove_index + 100:]
-                token_count = get_token_count_tiktoken_thread(prompt)
+                token_count = len(tokenizer.encode(prompt))
                 if token_count <= context_token_limit or remove_index == len(prompt):
                     break
 
-            def get_token_count_thread(msg):
-                return get_token_count(msg, backend_url)
-
-            token_count = get_token_count_thread(prompt)
+            token_count = get_token_count(prompt, backend_url)
             if token_count > context_token_limit:
                 # Start over, but this time calculate the token count using the backend
-                token_count = get_token_count_thread(prompt)
+                token_count = get_token_count(prompt, backend_url)
             else:
                 break
 
-    print(token_count)
     return prompt
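
Both trim helpers above follow the same two-pass idea: trim inside the hot loop with a cheap local estimate (tiktoken), then re-check with the backend's authoritative count before finishing or looping again. A self-contained sketch of that pattern, with placeholder callables standing in for tiktoken and the backend tokenizer (the names below are illustrative, not project APIs):

    # Sketch: trim text to a token budget using a fast estimator in the loop
    # and a slower authoritative counter to confirm the result.
    from typing import Callable

    def trim_to_fit(text: str, limit: int,
                    estimate: Callable[[str], int],
                    authoritative_count: Callable[[str], int],
                    chop: int = 100) -> str:
        # Fast pass: chop with the cheap local estimate.
        while estimate(text) > limit and len(text) > chop:
            text = text[chop:]
        # Slow pass: confirm (and keep chopping) with the authoritative count.
        while authoritative_count(text) > limit and len(text) > chop:
            text = text[chop:]
        return text

    # Example: a crude word-based estimator vs. a stricter "real" counter.
    est = lambda s: len(s.split())
    real = lambda s: int(len(s.split()) * 1.3)
    print(len(trim_to_fit('word ' * 500, 100, est, real).split()))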

View File

@@ -1,29 +1,35 @@
-import requests
+import asyncio
+
+import aiohttp
 import tiktoken
 
 from llm_server import opts
-from llm_server.cluster.cluster_config import cluster_config
 
 
 def tokenize(prompt: str, backend_url: str) -> int:
     assert backend_url
     if not prompt:
+        # The tokenizers have issues when the prompt is None.
         return 0
-    tokenizer = tiktoken.get_encoding("cl100k_base")
-    token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
 
-    # First we tokenize it locally to determine if it's worth sending it to the backend.
-    initial_estimate = len(tokenizer.encode(prompt))
-    if initial_estimate <= token_limit + 200:
-        try:
-            r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
-            j = r.json()
-            return j['length']
-        except Exception as e:
-            print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
-            return len(tokenizer.encode(prompt)) + 10
-    else:
-        # If the result was greater than our context size, return the estimate.
-        # We won't be sending it through the backend so it does't need to be accurage.
-        return initial_estimate
+    async def run():
+        tokenizer = tiktoken.get_encoding("cl100k_base")
+
+        async def send_chunk(chunk):
+            try:
+                async with session.post(f'{backend_url}/tokenize', json={'input': chunk}, verify_ssl=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) as response:
+                    j = await response.json()
+                    return j['length']
+            except Exception as e:
+                print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
+                return len(tokenizer.encode(chunk)) + 10
+
+        chunk_size = 300
+        chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
+
+        async with aiohttp.ClientSession() as session:
+            tasks = [send_chunk(chunk) for chunk in chunks]
+            lengths = await asyncio.gather(*tasks)
+
+        return sum(lengths)
+
+    return asyncio.run(run())
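
The rewrite splits the prompt into 300-character chunks and tokenizes them against the backend concurrently, summing the per-chunk lengths (slightly less exact at chunk boundaries, but much faster for long prompts). A stripped-down sketch of that fan-out-and-sum shape, with a dummy async counter standing in for the real aiohttp POST:

    # Sketch: concurrent chunked token counting. count_chunk() is a stand-in
    # for the HTTP call to the backend's /tokenize route.
    import asyncio

    async def count_chunk(chunk: str) -> int:
        await asyncio.sleep(0)          # pretend this is network I/O
        return max(1, len(chunk) // 4)  # crude ~4 chars/token estimate

    def count_tokens(prompt: str, chunk_size: int = 300) -> int:
        async def run():
            chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
            lengths = await asyncio.gather(*(count_chunk(c) for c in chunks))
            return sum(lengths)
        return asyncio.run(run())

    print(count_tokens('hello world ' * 200))  # all chunks counted in parallel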

View File

@@ -13,7 +13,7 @@ class OobaRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def handle_request(self):
+    def handle_request(self, return_ok: bool = True):
         assert not self.used
 
         request_valid, invalid_response = self.validate_request()
@@ -25,14 +25,19 @@ class OobaRequestHandler(RequestHandler):
         llm_request = {**self.parameters, 'prompt': prompt}
         _, backend_response = self.generate_response(llm_request)
-        return backend_response
+        if return_ok:
+            # Always return 200 so ST displays our error messages
+            return backend_response[0], 200
+        else:
+            # The OpenAI route needs to detect 429 errors.
+            return backend_response
 
     def handle_ratelimited(self, do_log: bool = True):
         msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
         backend_response = self.handle_error(msg)
         if do_log:
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
-        return backend_response[0], 200  # We only return the response from handle_error(), not the error code
+        return backend_response[0], 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

View File

@@ -8,10 +8,11 @@ from llm_server.custom_redis import redis
 from . import openai_bp
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
+from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
-from ...llm.openai.oai_to_vllm import oai_to_vllm
+from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -24,11 +25,6 @@ def openai_chat_completions():
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
         handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
-
-        if handler.cluster_backend_info['mode'] != 'vllm':
-            # TODO: implement other backends
-            raise NotImplementedError
-
         if not request_json_body.get('stream'):
             try:
                 return handler.handle_request()
@@ -37,30 +33,51 @@ def openai_chat_completions():
                 return 'Internal server error', 500
         else:
             if not opts.enable_streaming:
-                # TODO: return a proper OAI error message
-                return 'disabled', 401
+                return 'DISABLED', 401
+
+            invalid_oai_err_msg = validate_oai(handler.request_json_body)
+            if invalid_oai_err_msg:
+                return invalid_oai_err_msg
+            handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
 
             if opts.openai_silent_trim:
-                handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+                handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
+            else:
+                handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
 
             response_status_code = 0
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
                 return invalid_response
             else:
-                if opts.openai_silent_trim:
-                    oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
-                else:
-                    oai_messages = handler.request.json['messages']
-
-                handler.prompt = transform_messages_to_prompt(oai_messages)
-                handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
+
+                # Add a dummy event to the queue and wait for it to reach a worker
+                event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
+                if not event:
+                    log_prompt(
+                        handler.client_ip,
+                        handler.token,
+                        handler.prompt,
+                        None,
+                        None,
+                        handler.parameters,
+                        request.headers,
+                        response_status_code,
+                        request.url,
+                        handler.backend_url,
+                    )
+                    return handler.handle_ratelimited()
+
+                # Wait for a worker to get our request and discard it.
+                _, _, _ = event.wait()
+
                 try:
                     response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
@@ -69,57 +86,61 @@ def openai_chat_completions():
                     oai_string = generate_oai_string(30)
 
                     def generate():
+                        try:
                             generated_text = ''
                             partial_response = b''
                             for chunk in response.iter_content(chunk_size=1):
                                 partial_response += chunk
                                 if partial_response.endswith(b'\x00'):
                                     json_strs = partial_response.split(b'\x00')
                                     for json_str in json_strs:
                                         if json_str:
                                             try:
                                                 json_obj = json.loads(json_str.decode())
                                                 new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                                 generated_text = generated_text + new
                                             except IndexError:
                                                 # ????
                                                 continue
 
                                             data = {
                                                 "id": f"chatcmpl-{oai_string}",
                                                 "object": "chat.completion.chunk",
                                                 "created": int(time.time()),
                                                 "model": model,
                                                 "choices": [
                                                     {
                                                         "index": 0,
                                                         "delta": {
                                                             "content": new
                                                         },
                                                         "finish_reason": None
                                                     }
                                                 ]
                                             }
                                             yield f'data: {json.dumps(data)}\n\n'
 
                             yield 'data: [DONE]\n\n'
                             end_time = time.time()
                             elapsed_time = end_time - start_time
 
                             log_prompt(
                                 handler.client_ip,
                                 handler.token,
                                 handler.prompt,
                                 generated_text,
                                 elapsed_time,
                                 handler.parameters,
                                 r_headers,
                                 response_status_code,
                                 r_url,
                                 handler.backend_url,
                             )
+                        finally:
+                            # The worker incremented it, we'll decrement it.
+                            decrement_ip_count(handler.client_ip, 'processing_ips')
+                            decr_active_workers(handler.selected_model, handler.backend_url)
 
                     return Response(generate(), mimetype='text/event-stream')
-                except:
-                    # TODO: simulate OAI here
-                    raise Exception
+                except Exception:
+                    traceback.print_exc()
+                    return 'INTERNAL SERVER', 500
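
The streaming branch emits OpenAI-style server-sent events: one `data: {...}` chunk per delta and a terminating `data: [DONE]`. A hedged client-side sketch of consuming that stream (the URL and model name are placeholders and will differ per deployment):

    # Sketch: read the chat.completion.chunk SSE stream produced above.
    import json
    import requests

    resp = requests.post(
        'http://localhost:5000/api/openai/v1/chat/completions',  # placeholder URL
        json={'model': 'gpt-3.5-turbo', 'stream': True,
              'messages': [{'role': 'user', 'content': 'Hello!'}]},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue
        payload = line[len('data: '):]
        if payload == '[DONE]':
            break
        delta = json.loads(payload)['choices'][0]['delta']
        print(delta.get('content', ''), end='', flush=True)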

View File

@@ -8,6 +8,7 @@ from llm_server.custom_redis import redis
 from . import openai_bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
+from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
 from ...database.database import log_prompt
 from ...llm import get_token_count
@@ -24,80 +25,98 @@ def openai_completions():
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
-        try:
-            handler = OobaRequestHandler(incoming_request=request)
+        handler = OobaRequestHandler(incoming_request=request)
 
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
             raise NotImplementedError
 
         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg
 
         handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
 
-        # Convert parameters to the selected backend type
         if opts.openai_silent_trim:
             handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
         else:
             # The handle_request() call below will load the prompt so we don't have
             # to do anything else here.
             pass
 
         if not request_json_body.get('stream'):
-            response, status_code = handler.handle_request()
-            if status_code != 200:
-                return status_code
+            response, status_code = handler.handle_request(return_ok=False)
+            if status_code == 429:
+                return handler.handle_ratelimited()
             output = response.json['results'][0]['text']
 
             # TODO: async/await
             prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
             response_tokens = get_token_count(output, handler.backend_url)
             running_model = redis.get('running_model', 'ERROR', dtype=str)
 
             response = jsonify({
                 "id": f"cmpl-{generate_oai_string(30)}",
                 "object": "text_completion",
                 "created": int(time.time()),
                 "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
                 "choices": [
                     {
                         "text": output,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "stop"
                     }
                 ],
                 "usage": {
                     "prompt_tokens": prompt_tokens,
                     "completion_tokens": response_tokens,
                     "total_tokens": prompt_tokens + response_tokens
                 }
             })
 
             stats = redis.get('proxy_stats', dtype=dict)
             if stats:
                 response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
             return response, 200
         else:
             if not opts.enable_streaming:
-                # TODO: return a proper OAI error message
-                return 'disabled', 401
+                return 'DISABLED', 401
 
             response_status_code = 0
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
-                # TODO: simulate OAI here
-                raise Exception('TODO: simulate OAI here')
+                return invalid_response
             else:
                 handler.prompt = handler.request_json_body['prompt']
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
 
+                # Add a dummy event to the queue and wait for it to reach a worker
+                event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
+                if not event:
+                    log_prompt(
+                        handler.client_ip,
+                        handler.token,
+                        handler.prompt,
+                        None,
+                        None,
+                        handler.parameters,
+                        request.headers,
+                        response_status_code,
+                        request.url,
+                        handler.backend_url,
+                    )
+                    return handler.handle_ratelimited()
+
+                # Wait for a worker to get our request and discard it.
+                _, _, _ = event.wait()
+
+                try:
                     response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
                     r_url = request.url
@@ -105,57 +124,61 @@ def openai_completions():
                     oai_string = generate_oai_string(30)
 
                     def generate():
+                        try:
                             generated_text = ''
                             partial_response = b''
                             for chunk in response.iter_content(chunk_size=1):
                                 partial_response += chunk
                                 if partial_response.endswith(b'\x00'):
                                     json_strs = partial_response.split(b'\x00')
                                     for json_str in json_strs:
                                         if json_str:
                                             try:
                                                 json_obj = json.loads(json_str.decode())
                                                 new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                                 generated_text = generated_text + new
                                             except IndexError:
                                                 # ????
                                                 continue
 
                                             data = {
-                                                "id": f"chatcmpl-{oai_string}",
+                                                "id": f"cmpl-{oai_string}",
                                                 "object": "text_completion",
                                                 "created": int(time.time()),
                                                 "model": model,
                                                 "choices": [
                                                     {
                                                         "index": 0,
                                                         "delta": {
                                                             "content": new
                                                         },
                                                         "finish_reason": None
                                                     }
                                                 ]
                                             }
                                             yield f'data: {json.dumps(data)}\n\n'
 
                             yield 'data: [DONE]\n\n'
                             end_time = time.time()
                             elapsed_time = end_time - start_time
 
                             log_prompt(
                                 handler.client_ip,
                                 handler.token,
                                 handler.prompt,
                                 generated_text,
                                 elapsed_time,
                                 handler.parameters,
                                 r_headers,
                                 response_status_code,
                                 r_url,
                                 handler.backend_url,
                             )
+                        finally:
+                            # The worker incremented it, we'll decrement it.
+                            decrement_ip_count(handler.client_ip, 'processing_ips')
+                            decr_active_workers(handler.selected_model, handler.backend_url)
 
                     return Response(generate(), mimetype='text/event-stream')
                 except Exception:
                     traceback.print_exc()
-                    return 'Internal Server Error', 500
+                    return 'INTERNAL SERVER', 500

View File

@@ -10,8 +10,9 @@ from flask import Response, jsonify, make_response
 import llm_server
 from llm_server import opts
+from llm_server.cluster.model_choices import get_model_choices
 from llm_server.custom_redis import redis
-from llm_server.database.database import is_api_key_moderated
+from llm_server.database.database import is_api_key_moderated, log_prompt
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.routes.request_handler import RequestHandler
@@ -70,9 +71,24 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code
 
     def handle_ratelimited(self, do_log: bool = True):
-        # TODO: return a simulated OpenAI error message
-        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
-        return 'Ratelimited', 429
+        _, default_backend_info = get_model_choices()
+        w = int(default_backend_info['estimated_wait']) if default_backend_info['estimated_wait'] > 0 else 2
+        response = jsonify({
+            "error": {
+                "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
+                "type": "rate_limit_exceeded",
+                "param": None,
+                "code": None
+            }
+        })
+        response.headers['x-ratelimit-limit-requests'] = '2'
+        response.headers['x-ratelimit-remaining-requests'] = '0'
+        response.headers['x-ratelimit-reset-requests'] = f"{w}s"
+        if do_log:
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+        return response, 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         return jsonify({
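
handle_ratelimited() now returns a simulated OpenAI rate-limit error and advertises a retry window through `x-ratelimit-reset-requests`. A hedged sketch of a client that honors that header (the endpoint URL is a placeholder):

    # Sketch: retry on the simulated OpenAI 429, waiting for the advertised window.
    import time
    import requests

    def post_with_retry(url: str, payload: dict, max_attempts: int = 3):
        for _ in range(max_attempts):
            r = requests.post(url, json=payload)
            if r.status_code != 429:
                return r
            # Header looks like "6s"; fall back to 2 seconds if it is missing or odd.
            reset = r.headers.get('x-ratelimit-reset-requests', '2s').rstrip('s')
            time.sleep(float(reset) if reset.replace('.', '', 1).isdigit() else 2)
        return r

    # Usage (placeholder URL):
    # post_with_retry('http://localhost:5000/api/openai/v1/chat/completions',
    #                 {'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'hi'}]})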

View File

@@ -209,7 +209,7 @@ class RequestHandler:
         if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
             return False
         else:
-            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
+            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
             return True
 
     def handle_request(self) -> Tuple[flask.Response, int]:

View File

@@ -115,6 +115,10 @@ def do_stream(ws, model_name):
             err_msg = r.json['results'][0]['text']
             send_err_and_quit(err_msg)
             return
+
+        # Wait for a worker to get our request and discard it.
+        _, _, _ = event.wait()
+
         try:
             response = generator(llm_request, handler.backend_url)
             if not response:

View File

@@ -6,6 +6,7 @@ from llm_server.custom_redis import flask_cache
 from . import bp
 from ... import opts
 from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
+from ...cluster.cluster_config import cluster_config
 
 
 @bp.route('/v1/model', methods=['GET'])

View File

@@ -21,8 +21,10 @@ def worker():
         incr_active_workers(selected_model, backend_url)
 
         if not request_json_body:
-            # This was a dummy request from the websocket handler.
+            # This was a dummy request from the websocket handlers.
             # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+            event = DataEvent(event_id)
+            event.set((True, None, None))
             continue
 
         try:
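
Together with the route changes above, this completes the streaming handshake: the route enqueues a dummy item, the worker acknowledges it through the DataEvent and skips generation, and the route then streams directly and decrements the counters itself. A toy sketch of that handshake, using a plain queue and threading.Event in place of the project's Redis-backed priority queue and DataEvent:

    # Sketch: dummy-item handshake between a route and a worker thread.
    import queue
    import threading

    work_q = queue.Queue()

    def worker():
        while True:
            request_body, ack = work_q.get()
            if request_body is None:
                # Dummy request from a streaming route: acknowledge and move on,
                # letting the route do the generation and bookkeeping itself.
                ack.set()
                continue
            # ... normal (non-streaming) generation would happen here ...

    threading.Thread(target=worker, daemon=True).start()

    def streaming_route():
        ack = threading.Event()
        work_q.put((None, ack))  # reserve a slot in the queue
        ack.wait()               # a worker has picked us up; stream directly now
        return 'streamed response'

    print(streaming_route())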

View File

@@ -14,3 +14,5 @@ urllib3~=2.0.4
 flask-sock==0.6.0
 gunicorn==21.2.0
 redis==5.0.1
+aiohttp==3.8.5
+asyncio==3.4.3