from typing import Tuple, Union

from flask import jsonify
from vllm import SamplingParams

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.http import validate_json


class VLLMBackend(LLMBackend):
    default_params = vars(SamplingParams())

    def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        if len(response_json_body.get('text', [])):
            # vLLM returns the prompt and the generated text together, so strip the prompt off the front.
            backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
        else:
            # Failsafe
            backend_response = ''
        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time,
                   parameters=parameters, headers=headers, backend_response_code=response_status_code,
                   request_url=request.url,
                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
        return jsonify({'results': [{'text': backend_response}]}), 200

    def get_parameters(self, parameters) -> Tuple[Union[dict, None], Union[str, None]]:
        try:
            # Map the client-facing parameter names onto vLLM's SamplingParams, falling back to vLLM's defaults.
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self.default_params['temperature']),
                top_p=parameters.get('top_p', self.default_params['top_p']),
                top_k=parameters.get('top_k', self.default_params['top_k']),
                use_beam_search=parameters.get('num_beams', 0) > 1,
                stop=parameters.get('stopping_strings', self.default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None

    def validate_request(self, parameters) -> Tuple[bool, Union[str, None]]:
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
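

# Minimal usage sketch (not part of the server): shows how client-supplied generation
# parameters are translated into SamplingParams kwargs by get_parameters(). The
# parameter dict below is hypothetical and only illustrates the expected key names.
if __name__ == '__main__':
    backend = VLLMBackend()
    client_params = {
        'temperature': 0.7,
        'top_p': 0.9,
        'max_new_tokens': 200,
        'stopping_strings': ['\n###'],
    }
    sampling_kwargs, error = backend.get_parameters(client_params)
    if error:
        print(f'Invalid parameters: {error}')
    else:
        print(sampling_kwargs)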