actually validate prompt length lol
parent 3100b0a924
commit 77edbe779c
@@ -2,6 +2,9 @@ from typing import Tuple, Union
 
 import flask
 
+from llm_server import opts
+from llm_server.database import tokenizer
+
 
 class LLMBackend:
     _default_params: dict
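Note (not part of the commit): the two new imports give the base backend access to the shared tokenizer and the server-wide settings, because the new check below measures prompts in tokens rather than characters. A rough sketch of that distinction, assuming tokenizer follows the usual Hugging Face encode() interface and opts.context_size holds the model's maximum context length in tokens:

    # Illustration only: token count vs. character count for a prompt.
    from llm_server import opts
    from llm_server.database import tokenizer

    prompt = "The quick brown fox jumps over the lazy dog."
    print(len(prompt))                    # characters, e.g. 44
    print(len(tokenizer.encode(prompt)))  # tokens, typically far fewer
    print(opts.context_size)              # the limit the new check compares against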
@@ -26,3 +29,11 @@ class LLMBackend:
 
     def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
         raise NotImplementedError
+
+    @staticmethod
+    def validate_prompt(prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = len(tokenizer.encode(prompt))
+        print(prompt_len, opts.context_size)
+        if prompt_len > opts.context_size - 10:  # Our tokenizer isn't 100% accurate so we cut it down a bit. TODO: add a tokenizer endpoint to VLLM
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+        return True, None
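Usage sketch for the new check (illustration only; it assumes the module-level tokenizer from llm_server.database and some opts.context_size such as 4096). Because validate_prompt is a staticmethod, it can be called on the base class directly, and it returns the same (ok, error-message-or-None) tuple shape as validate_request:

    # Hypothetical quick check of the two return shapes.
    from llm_server import opts
    from llm_server.llm.llm_backend import LLMBackend

    ok, err = LLMBackend.validate_prompt("a short prompt")
    assert ok is True and err is None       # fits within context_size - 10

    too_long = "word " * opts.context_size  # long enough to exceed the limit with typical tokenizers
    ok, err = LLMBackend.validate_prompt(too_long)
    assert ok is False                      # err carries the user-facing message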
@@ -6,7 +6,6 @@ from vllm import SamplingParams
 from llm_server import opts
 from llm_server.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
-from llm_server.routes.helpers.http import validate_json
 
 
 class VLLMBackend(LLMBackend):
@@ -68,12 +68,19 @@ class RequestHandler:
             combined_error_message = ', '.join(error_messages)
             backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
             return False, self.handle_error(backend_response)
         return True, (None, 0)
 
-
     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
+        prompt = llm_request['prompt']
         if not self.is_client_ratelimited():
+            # Validate the prompt right before submission since the backend handler may have changed something.
+            prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
+            if not prompt_valid:
+                backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+                return (False, None, None, 0), self.handle_error(backend_response)
             event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
         else:
             event = None
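The TODO added in this hunk points at a follow-up: errors are still formatted in the SillyTavern style only, while the TODO notes the server has both Ooba and OpenAI response types. One possible shape for that future LLMBackend method, purely a sketch (the method name format_err_response and the subclass split shown here are hypothetical, not part of this commit):

    class LLMBackend:
        def format_err_response(self, msg: str) -> dict:
            # Each concrete backend knows which response schema its clients expect.
            raise NotImplementedError

    class OobaStyleBackend(LLMBackend):
        def format_err_response(self, msg: str) -> dict:
            # text-generation-webui ("Ooba") style clients read results[0].text
            return {'results': [{'text': msg}]}

    class OpenAIStyleBackend(LLMBackend):
        def format_err_response(self, msg: str) -> dict:
            # OpenAI-compatible chat clients read choices[0].message.content
            return {'choices': [{'message': {'role': 'assistant', 'content': msg}}]}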
@@ -81,8 +88,6 @@ class RequestHandler:
         if not event:
             return (False, None, None, 0), self.handle_ratelimited()
 
-        prompt = llm_request['prompt']
-
         success, response, error_msg = event.wait()
         end_time = time.time()
         elapsed_time = end_time - self.start_time