move where the vllm model is set
parent 4c9d543eab
commit 747d838138
@@ -106,10 +106,6 @@ def handle_blocking_request(json_data: dict):


 def generate(json_data: dict):
-    full_model_path = redis.get('full_model_path')
-    if not full_model_path:
-        raise Exception
-    json_data['model'] = full_model_path.decode()
     if json_data.get('stream'):
         raise Exception('streaming not implemented')
     else:
@@ -99,6 +99,12 @@ class OobaRequestHandler:
         if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
             self.request_json_body['typical_p'] = 0.998

+        if opts.mode == 'vllm':
+            full_model_path = redis.get('full_model_path')
+            if not full_model_path:
+                raise Exception
+            self.request_json_body['model'] = full_model_path.decode()
+
         request_valid, invalid_request_err_msg = self.validate_request()
         params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
         if not request_valid or not params_valid:
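For reference, a minimal standalone sketch of the lookup this commit relocates into the request handler. The class skeleton, method name, and error message below are illustrative only; the shared Redis client and the opts.mode flag are assumed to exist elsewhere in the project. The .decode() is needed because redis-py returns bytes.

import redis

r = redis.Redis()  # stand-in for the project's shared Redis client


class OobaRequestHandler:  # simplified skeleton, not the real class
    def __init__(self, request_json_body: dict, mode: str):
        self.request_json_body = request_json_body
        self.mode = mode  # stand-in for opts.mode

    def set_vllm_model(self):
        # The vLLM model name is now filled in here, before request/parameter
        # validation, instead of inside the blocking generate() path.
        if self.mode == 'vllm':
            full_model_path = r.get('full_model_path')
            if not full_model_path:
                raise Exception('full_model_path is not set in Redis')
            # redis-py returns bytes, hence the decode()
            self.request_json_body['model'] = full_model_path.decode()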