Merge cluster to master #3
|
@ -77,3 +77,18 @@ def validate_oai(parameters):
|
||||||
|
|
||||||
if parameters.get('max_tokens', 2) < 1:
|
if parameters.get('max_tokens', 2) < 1:
|
||||||
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
|
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
|
||||||
|
|
||||||
|
|
||||||
|
def return_invalid_model_err(requested_model: str):
|
||||||
|
if requested_model:
|
||||||
|
msg = f"The model `{requested_model}` does not exist"
|
||||||
|
else:
|
||||||
|
msg = "The requested model does not exist"
|
||||||
|
return jsonify({
|
||||||
|
"error": {
|
||||||
|
"message": msg,
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": None,
|
||||||
|
"code": "model_not_found"
|
||||||
|
}
|
||||||
|
}), 404
|
||||||
|
|
|
@ -12,7 +12,7 @@ from ..queue import priority_queue
|
||||||
from ... import opts
|
from ... import opts
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
||||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +27,11 @@ def openai_chat_completions(model_name=None):
|
||||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||||
else:
|
else:
|
||||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||||
|
if handler.offline:
|
||||||
|
msg = return_invalid_model_err(model_name)
|
||||||
|
print(msg)
|
||||||
|
return handler.handle_error(msg)
|
||||||
|
|
||||||
if not request_json_body.get('stream'):
|
if not request_json_body.get('stream'):
|
||||||
try:
|
try:
|
||||||
invalid_oai_err_msg = validate_oai(request_json_body)
|
invalid_oai_err_msg = validate_oai(request_json_body)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ... import opts
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm import get_token_count
|
from ...llm import get_token_count
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
||||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +27,10 @@ def openai_completions(model_name=None):
|
||||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||||
else:
|
else:
|
||||||
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||||
|
if handler.offline:
|
||||||
|
msg = return_invalid_model_err(model_name)
|
||||||
|
print(msg)
|
||||||
|
return handler.handle_error(msg)
|
||||||
|
|
||||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||||
# TODO: implement other backends
|
# TODO: implement other backends
|
||||||
|
|
|
@ -11,10 +11,10 @@ from flask import Response, jsonify, make_response
|
||||||
from llm_server import opts
|
from llm_server import opts
|
||||||
from llm_server.cluster.backend import get_model_choices
|
from llm_server.cluster.backend import get_model_choices
|
||||||
from llm_server.custom_redis import redis
|
from llm_server.custom_redis import redis
|
||||||
from llm_server.database.database import is_api_key_moderated, do_db_log
|
from llm_server.database.database import is_api_key_moderated
|
||||||
from llm_server.database.log_to_db import log_to_db
|
from llm_server.database.log_to_db import log_to_db
|
||||||
from llm_server.llm import get_token_count
|
from llm_server.llm import get_token_count
|
||||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
||||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||||
from llm_server.routes.request_handler import RequestHandler
|
from llm_server.routes.request_handler import RequestHandler
|
||||||
from llm_server.workers.moderator import add_moderation_task, get_results
|
from llm_server.workers.moderator import add_moderation_task, get_results
|
||||||
|
@ -27,6 +27,10 @@ class OpenAIRequestHandler(RequestHandler):
|
||||||
|
|
||||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||||
assert not self.used
|
assert not self.used
|
||||||
|
if self.offline:
|
||||||
|
msg = return_invalid_model_err(self.selected_model)
|
||||||
|
print(msg)
|
||||||
|
return self.handle_error(msg)
|
||||||
|
|
||||||
if opts.openai_silent_trim:
|
if opts.openai_silent_trim:
|
||||||
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
||||||
|
@ -63,7 +67,7 @@ class OpenAIRequestHandler(RequestHandler):
|
||||||
|
|
||||||
if not self.prompt:
|
if not self.prompt:
|
||||||
# TODO: format this as an openai error message
|
# TODO: format this as an openai error message
|
||||||
return 'Invalid prompt', 400
|
return Response('Invalid prompt'), 400
|
||||||
|
|
||||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
|
@ -41,6 +42,8 @@ def worker(backend_url):
|
||||||
success, response, error_msg = generator(request_json_body, backend_url)
|
success, response, error_msg = generator(request_json_body, backend_url)
|
||||||
event = DataEvent(event_id)
|
event = DataEvent(event_id)
|
||||||
event.set((success, response, error_msg))
|
event.set((success, response, error_msg))
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
decrement_ip_count(client_ip, 'processing_ips')
|
decrement_ip_count(client_ip, 'processing_ips')
|
||||||
decr_active_workers(selected_model, backend_url)
|
decr_active_workers(selected_model, backend_url)
|
||||||
|
|
Reference in New Issue