actually validate prompt length lol

Cyberes 2023-09-14 18:31:13 -06:00
parent 3100b0a924
commit 77edbe779c
3 changed files with 19 additions and 4 deletions

View File

@@ -2,6 +2,9 @@ from typing import Tuple, Union
 import flask

+from llm_server import opts
+from llm_server.database import tokenizer

 class LLMBackend:
     _default_params: dict
@@ -26,3 +29,11 @@ class LLMBackend:
     def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
         raise NotImplementedError
+
+    @staticmethod
+    def validate_prompt(prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = len(tokenizer.encode(prompt))
+        print(prompt_len, opts.context_size)
+        if prompt_len > opts.context_size - 10:  # Our tokenizer isn't 100% accurate so we cut it down a bit. TODO: add a tokenizer endpoint to VLLM
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+        return True, None
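For reference, a minimal runnable sketch of the length check above, outside the repository: CONTEXT_SIZE and the whitespace split are stand-ins for opts.context_size and tokenizer.encode, and the 10-token safety margin mirrors the comment in the diff.

from typing import Optional, Tuple

CONTEXT_SIZE = 16  # assumed small value so the second call below trips the limit

def check_prompt_length(prompt: str) -> Tuple[bool, Optional[str]]:
    # Stand-in for len(tokenizer.encode(prompt)); the real tokenizer is only
    # approximately aligned with the backend, hence the 10-token headroom.
    prompt_len = len(prompt.split())
    if prompt_len > CONTEXT_SIZE - 10:
        return False, f'prompt is {prompt_len} tokens, limit is {CONTEXT_SIZE}'
    return True, None

print(check_prompt_length('a short prompt'))   # (True, None)
print(check_prompt_length('word ' * 20))       # (False, 'prompt is 20 tokens, limit is 16')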

View File

@@ -6,7 +6,6 @@ from vllm import SamplingParams
 from llm_server import opts
 from llm_server.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
-from llm_server.routes.helpers.http import validate_json

 class VLLMBackend(LLMBackend):

View File

@@ -68,12 +68,19 @@ class RequestHandler:
             combined_error_message = ', '.join(error_messages)
             backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
             return False, self.handle_error(backend_response)
         return True, (None, 0)

     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
+        prompt = llm_request['prompt']
         if not self.is_client_ratelimited():
+            # Validate the prompt right before submission since the backend handler may have changed something.
+            prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
+            if not prompt_valid:
+                backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+                return (False, None, None, 0), self.handle_error(backend_response)
             event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
         else:
             event = None
@@ -81,8 +88,6 @@ class RequestHandler:
         if not event:
             return (False, None, None, 0), self.handle_ratelimited()

-        prompt = llm_request['prompt']

         success, response, error_msg = event.wait()
         end_time = time.time()
         elapsed_time = end_time - self.start_time
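Taken together, generate_response now re-checks the prompt right before it enters the queue, since earlier handlers may have rewritten it after the first validation. A simplified, self-contained sketch of that ordering (every name here is a stand-in, not the repository's API):

from typing import Optional, Tuple

def validate_prompt(prompt: str, limit: int = 8) -> Tuple[bool, Optional[str]]:
    tokens = prompt.split()  # stand-in tokenizer
    if len(tokens) > limit:
        return False, f'{len(tokens)} > {limit}'
    return True, None

def generate_response(prompt: str, ratelimited: bool) -> Tuple[bool, str]:
    if ratelimited:
        return False, 'ratelimited'        # mirrors handle_ratelimited()
    ok, err = validate_prompt(prompt)      # re-validate just before submission
    if not ok:
        return False, f'Validation Error: {err}'
    return True, 'queued'                  # stand-in for priority_queue.put(...)

print(generate_response('hello there', ratelimited=False))  # (True, 'queued')
print(generate_response('x ' * 20, ratelimited=False))      # (False, 'Validation Error: 20 > 8')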