"""
|
|
|
|
This file is used by the worker that processes requests.
|
|
|
|
"""
import json
import time
from uuid import uuid4

import requests

import llm_server
from llm_server import opts
from llm_server.custom_redis import redis

# TODO: make the vLLM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py


def prepare_json(json_data: dict):
    """Rename request parameters to the names the vLLM API expects (e.g. max_new_tokens -> max_tokens)."""
    # logit_bias is not currently supported
    # del json_data['logit_bias']

    # Convert back to vLLM's parameter name.
    json_data['max_tokens'] = json_data.pop('max_new_tokens')
    return json_data
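
# Illustrative example (values are made up): prepare_json({'prompt': 'Hi', 'max_new_tokens': 64})
# returns {'prompt': 'Hi', 'max_tokens': 64}. Note that json_data is mutated in place and
# 'max_new_tokens' must be present, otherwise .pop() raises KeyError.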


def transform_prompt_to_text(prompt: list):
    """Flatten a list of message dicts into a single newline-separated text prompt."""
    text = ''
    for item in prompt:
        text += item['content'] + '\n'
    return text.strip('\n')
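
# Illustrative example (assumes OpenAI-style message dicts; only 'content' is read here):
#   transform_prompt_to_text([{'role': 'user', 'content': 'Hello'},
#                             {'role': 'assistant', 'content': 'Hi there'}])
#   -> 'Hello\nHi there'
# Roles are dropped and the messages are simply joined with newlines.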


def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
    """POST a blocking generate request to the backend and normalize the outcome.

    Returns a (success, response, error_message) tuple.
    """
    try:
        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
    except requests.exceptions.ReadTimeout:
        # print(f'Failed to reach vLLM inference endpoint - request to backend timed out')
        return False, None, 'Request to backend timed out'
    except Exception as e:
        # print(f'Failed to reach vLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        # print(f'Failed to reach vLLM inference endpoint - got code {r.status_code}')
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
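
# The tuple above is (success, response, error_message):
#   (True,  <requests.Response>, None)                     on HTTP 200,
#   (False, <requests.Response>, 'Backend returned ...')   on any other status code,
#   (False, None, '...')                                   when the request itself failed.
# A hypothetical caller sketch (names here are illustrative, not part of the project API):
#   ok, resp, err = handle_blocking_request(payload, 'http://backend-host:7000')
#   if ok:
#       result = resp.json()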


def generate(json_data: dict, cluster_backend, timeout: int = None):
    """Dispatch a generate request to the backend.

    Streaming requests return the raw requests.Response (or False on failure);
    blocking requests return handle_blocking_request()'s (success, response, error) tuple.
    """
    if json_data.get('stream'):
        try:
            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
        except Exception:
            return False
    else:
        return handle_blocking_request(json_data, cluster_backend, timeout=timeout)
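

if __name__ == '__main__':
    # Minimal usage sketch, not part of the worker itself. It assumes a reachable vLLM
    # backend and that llm_server.opts is already configured (verify_ssl,
    # backend_generate_request_timeout); the URL and payload below are placeholders.
    example_backend = 'http://127.0.0.1:7000'  # hypothetical backend address
    payload = {'prompt': 'Say hello.', 'max_new_tokens': 16, 'stream': False}

    success, response, error_msg = generate(dict(payload), example_backend)
    if success:
        print(response.json())
    else:
        print(f'Generation failed: {error_msg}')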