import traceback
from typing import Tuple, Union
import requests
from flask import jsonify
from vllm import SamplingParams
import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend


class VLLMBackend(LLMBackend):
    # Snapshot vLLM's default sampling values so client-supplied parameters can
    # fall back to them in get_parameters() below.
    _default_params = vars(SamplingParams())
    def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        if len(response_json_body.get('text', [])):
            # The vLLM API server returns the prompt and the response together, so
            # strip the echoed prompt off the front. split(prompt, 1)[-1] still
            # yields the full text (rather than raising an IndexError) if the
            # backend ever stops echoing the prompt.
            backend_response = response_json_body['text'][0].split(prompt, 1)[-1].strip(' ').strip('\n')
        else:
            # Failsafe
            backend_response = ''
        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
        return jsonify({'results': [{'text': backend_response}]}), 200
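
    # For illustration, a sketch of the response body this handler expects from the
    # vLLM API server (field values here are hypothetical):
    #   response_json_body = {'text': ['<prompt><completion>'],
    #                         'details': {'generated_tokens': 42}}
    # After the echoed prompt is stripped, the client receives:
    #   {'results': [{'text': '<completion>'}]}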
    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        try:
            # top_k == -1 means disabled
            top_k = parameters.get('top_k', self._default_params['top_k'])
            if top_k <= 0:
                top_k = -1
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self._default_params['temperature']),
                top_p=parameters.get('top_p', self._default_params['top_p']),
                top_k=top_k,
                use_beam_search=parameters.get('num_beams', 0) > 1,
                stop=parameters.get('stopping_strings', self._default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None
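
    # Example mapping (hypothetical client values): a payload such as
    #   {'temperature': 0.7, 'top_k': 0, 'max_new_tokens': 200}
    # becomes SamplingParams(temperature=0.7, top_k=-1, max_tokens=200, ...),
    # because top_k <= 0 is normalized to vLLM's "disabled" sentinel of -1 and
    # unspecified fields keep the defaults captured in _default_params.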
    def validate_request(self, parameters) -> Tuple[bool, Union[str, None]]:
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
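
    # For example (assuming opts.max_new_tokens were 500): {'max_new_tokens': 512}
    # would be rejected with the message above, while {'max_new_tokens': 256} passes.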
    # def tokenize(self, prompt):
    #     try:
    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    #         j = r.json()
    #         return j['length']
    #     except:
    #         # Fall back to whatever the superclass is doing.
    #         print(traceback.format_exc())
    #         return super().tokenize(prompt)
    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = llm_server.llm.tokenizer(prompt)
        if prompt_len > opts.context_size:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
        return True, None
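
# A minimal usage sketch, assuming opts and llm_server.llm.tokenizer are already
# configured by the surrounding application (values below are hypothetical):
#
#   backend = VLLMBackend()
#   ok, err = backend.validate_prompt('Once upon a time')
#   if ok:
#       params, err = backend.get_parameters({'temperature': 0.7, 'max_new_tokens': 100})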