move where the vllm model is set

This commit is contained in:
Cyberes 2023-09-11 21:05:22 -06:00
parent 4c9d543eab
commit 747d838138
2 changed files with 6 additions and 4 deletions

View File

@ -106,10 +106,6 @@ def handle_blocking_request(json_data: dict):
def generate(json_data: dict): def generate(json_data: dict):
full_model_path = redis.get('full_model_path')
if not full_model_path:
raise Exception
json_data['model'] = full_model_path.decode()
if json_data.get('stream'): if json_data.get('stream'):
raise Exception('streaming not implemented') raise Exception('streaming not implemented')
else: else:

View File

@ -99,6 +99,12 @@ class OobaRequestHandler:
if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998: if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
self.request_json_body['typical_p'] = 0.998 self.request_json_body['typical_p'] = 0.998
if opts.mode == 'vllm':
full_model_path = redis.get('full_model_path')
if not full_model_path:
raise Exception
self.request_json_body['model'] = full_model_path.decode()
request_valid, invalid_request_err_msg = self.validate_request() request_valid, invalid_request_err_msg = self.validate_request()
params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters) params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
if not request_valid or not params_valid: if not request_valid or not params_valid: