actually validate prompt length lol
parent 3100b0a924
commit 77edbe779c
@@ -2,6 +2,9 @@ from typing import Tuple, Union
import flask

from llm_server import opts
from llm_server.database import tokenizer


class LLMBackend:
    _default_params: dict
@@ -26,3 +29,11 @@ class LLMBackend:
    def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
        raise NotImplementedError

    @staticmethod
    def validate_prompt(prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = len(tokenizer.encode(prompt))
        print(prompt_len, opts.context_size)
        if prompt_len > opts.context_size - 10:  # Our tokenizer isn't 100% accurate so we cut it down a bit. TODO: add a tokenizer endpoint to VLLM
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
        return True, None
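
For reference, a minimal usage sketch of the new LLMBackend.validate_prompt contract: it returns (ok, error_message_or_None). The context_size value below is an assumed example, and the sketch presumes the llm_server package and its tokenizer are importable; none of this is part of the commit itself.

    # Hypothetical usage sketch; opts.context_size is an assumed value.
    from llm_server import opts
    from llm_server.llm.llm_backend import LLMBackend

    opts.context_size = 2048  # assumed deployment setting
    ok, err = LLMBackend.validate_prompt('Hello, world! ' * 2000)
    if not ok:
        print(err)  # the prompt exceeds the model's context window
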
@@ -6,7 +6,6 @@ from vllm import SamplingParams
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.http import validate_json


class VLLMBackend(LLMBackend):
@@ -68,12 +68,19 @@ class RequestHandler:
            combined_error_message = ', '.join(error_messages)
            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return False, self.handle_error(backend_response)
        return True, (None, 0)

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        prompt = llm_request['prompt']
        if not self.is_client_ratelimited():
            # Validate the prompt right before submission since the backend handler may have changed something.
            prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
            if not prompt_valid:
                backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
                return (False, None, None, 0), self.handle_error(backend_response)

            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None
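
The pre-submission check above still relies on the approximate local token count (hence the "context_size - 10" safety margin and the TODO in validate_prompt). One way to make the count exact would be the tokenizer endpoint that TODO mentions; a rough sketch follows. The route path and response shape are assumptions, not part of this commit; only flask and llm_server.database.tokenizer come from the diff.

    # Hypothetical tokenizer endpoint sketch; path and payload are assumed.
    import flask
    from llm_server.database import tokenizer

    app = flask.Flask(__name__)

    @app.route('/api/v1/tokenize', methods=['POST'])
    def tokenize():
        prompt = (flask.request.get_json(silent=True) or {}).get('prompt', '')
        # Return the exact token count so the proxy can drop its fudge factor.
        return flask.jsonify({'length': len(tokenizer.encode(prompt))})
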
@@ -81,8 +88,6 @@ class RequestHandler:
        if not event:
            return (False, None, None, 0), self.handle_ratelimited()

        prompt = llm_request['prompt']

        success, response, error_msg = event.wait()
        end_time = time.time()
        elapsed_time = end_time - self.start_time
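
For readers unfamiliar with the queueing code, the hunk above depends on priority_queue.put() handing back an object whose wait() blocks until a worker finishes and then yields (success, response, error_msg). A minimal, illustrative stand-in for that contract is sketched below; it is not the project's actual implementation, and every name except wait() is an assumption.

    # Illustrative sketch of the event contract used by generate_response.
    import threading

    class QueuedJob:
        def __init__(self):
            self._done = threading.Event()
            self._result = (False, None, 'not finished')

        def set_result(self, success, response, error_msg):
            # Called by the worker once inference completes (or fails).
            self._result = (success, response, error_msg)
            self._done.set()

        def wait(self):
            # Block the request thread until the worker reports back.
            self._done.wait()
            return self._result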