from typing import Tuple

from flask import jsonify
from vllm import SamplingParams

from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend


class VLLMBackend(LLMBackend):
    """Backend adapter for a vLLM inference server.

    Translates Oobabooga-style request parameters into vLLM's
    ``SamplingParams`` and normalizes vLLM responses into the
    Oobabooga-style ``{'results': [{'text': ...}]}`` shape.
    """

    # Defaults pulled from vLLM itself so we always fall back to the
    # library's own defaults rather than hard-coding values here.
    _default_params = vars(SamplingParams())

    def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        """Normalize a vLLM response, log the exchange, and return a Flask response.

        :param success: unused here; kept for interface compatibility.
        :param request: the incoming Flask request (used for its URL).
        :param response_json_body: parsed JSON body returned by vLLM.
        :param response_status_code: HTTP status code from the backend.
        :param client_ip: originating client IP, for logging.
        :param token: API token used by the client, for logging.
        :param prompt: the prompt that was sent to the backend.
        :param elapsed_time: generation wall-clock time, for logging.
        :param parameters: the sampling parameters used, for logging.
        :param headers: request headers, for logging.
        :return: ``(flask_response, 200)`` tuple in Oobabooga result format.
        """
        text_items = response_json_body.get('text') or []
        if text_items:
            # vLLM appears to return the prompt and the completion
            # concatenated, so strip the prompt off to recover just the
            # generated text.
            raw_text = text_items[0]
            pieces = raw_text.split(prompt)
            if len(pieces) > 1:
                backend_response = pieces[1].strip(' ').strip('\n')
            else:
                # Prompt not found in the output (previously this raised
                # IndexError) — fall back to the raw text.
                backend_response = raw_text.strip(' ').strip('\n')
        else:
            # Failsafe: no text returned at all.
            backend_response = ''
        log_to_db(ip=client_ip,
                  token=token,
                  prompt=prompt,
                  response=backend_response,
                  gen_time=elapsed_time,
                  parameters=parameters,
                  headers=headers,
                  backend_response_code=response_status_code,
                  request_url=request.url,
                  response_tokens=response_json_body.get('details', {}).get('generated_tokens'),
                  backend_url=self.backend_url)
        return jsonify({'results': [{'text': backend_response}]}), 200

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        """Convert Oobabooga parameters to vLLM form and validate them.

        :param parameters: dict of Oobabooga-style sampling parameters.
        :return: ``(params_dict, None)`` on success, or ``(None, error_msg)``
            when vLLM's ``SamplingParams`` rejects the input.
        """
        try:
            # top_k == -1 means disabled in vLLM; normalize any
            # non-positive value to that sentinel.
            top_k = parameters.get('top_k', self._default_params['top_k'])
            if top_k <= 0:
                top_k = -1

            # We call the internal vLLM `SamplingParams` class to validate the
            # input parameters. Parameters from Oobabooga don't line up here
            # exactly, so we have to shuffle some things around.
            # TODO: support more params
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self._default_params['temperature']),
                top_p=parameters.get('top_p', self._default_params['top_p']),
                top_k=top_k,
                # Oobabooga signals beam search via num_beams > 1.
                use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
                # Accept either Oobabooga's 'stopping_strings' or the native
                # 'stop' key; de-duplicate the combined list.
                stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
                ignore_eos=parameters.get('ban_eos_token', False),
                # Oobabooga uses 'max_new_tokens'; vLLM uses 'max_tokens'.
                max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
                presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
                frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
                length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
                early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
            )
        except ValueError as e:
            # `SamplingParams` will return a pretty error message. Send that
            # back to the caller.
            return None, str(e).strip('.')

        # We use max_new_tokens throughout this program, so rename the variable.
        result = vars(sampling_params)
        result['max_new_tokens'] = result.pop('max_tokens')
        return result, None