""" This file is used by the worker that processes requests. """ import json import time from uuid import uuid4 import requests import llm_server from llm_server import opts from llm_server.routes.cache import redis # TODO: make the VLMM backend return TPS and time elapsed # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py def prepare_json(json_data: dict): # logit_bias is not currently supported # del json_data['logit_bias'] # Convert back to VLLM. json_data['max_tokens'] = json_data.pop('max_new_tokens') return json_data def transform_to_text(json_request, api_response): """ This is to convert a streaming request to a non-streamed request. Don't think this is nessesary. :param json_request: :param api_response: :return: """ prompt = transform_prompt_to_text(json_request['messages']) text = '' finish_reason = None for line in api_response.split('\n'): if line.startswith('data:'): try: data = json.loads(line[5:].strip()) except json.decoder.JSONDecodeError: break if 'choices' in data: for choice in data['choices']: if 'delta' in choice and 'content' in choice['delta']: text += choice['delta']['content'] if data['choices'][0]['finish_reason']: finish_reason = data['choices'][0]['finish_reason'] prompt_tokens = len(llm_server.llm.get_token_count(prompt)) completion_tokens = len(llm_server.llm.get_token_count(text)) running_model = redis.get('running_model', str, 'ERROR') # https://platform.openai.com/docs/api-reference/making-requests?lang=python return { "id": str(uuid4()), "object": "chat.completion", "created": int(time.time()), "model": running_model, "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens }, "choices": [ { "message": { "role": "assistant", "content": text }, "finish_reason": finish_reason, "index": 0 } ] } def transform_prompt_to_text(prompt: list): text = '' for item in prompt: text += item['content'] + '\n' return text.strip('\n') 
def handle_blocking_request(json_data: dict):
    """POST a non-streaming generate request to the VLLM backend.

    :param json_data: request payload in our schema (translated by prepare_json).
    :return: ``(success, response, error_msg)`` — ``response`` is None when the
             request never reached the backend; ``error_msg`` is None on success.
    """
    try:
        r = requests.post(
            f'{opts.backend_url}/generate',
            json=prepare_json(json_data),
            verify=opts.verify_ssl,
            timeout=opts.backend_generate_request_timeout,
        )
    except requests.exceptions.ReadTimeout:
        # f-prefix removed: the original f-strings here had no placeholders.
        print('Failed to reach VLLM inference endpoint - request to backend timed out')
        return False, None, 'Request to backend timed out'
    except Exception as e:
        print('Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None


def generate(json_data: dict):
    """Dispatch a generate request to the backend, streaming or blocking.

    Streaming (``json_data['stream']`` truthy): returns the live
    ``requests.Response`` (``stream=True``), or None if the request failed.
    Blocking: returns handle_blocking_request's ``(success, response, error_msg)``
    triple. Callers must handle both shapes.
    """
    if json_data.get('stream'):
        try:
            return requests.post(
                f'{opts.backend_url}/generate',
                json=prepare_json(json_data),
                stream=True,
                verify=opts.verify_ssl,
                timeout=opts.backend_generate_request_timeout,
            )
        except Exception as e:
            print('Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
            # BUG FIX (made explicit): the original fell off the end of the
            # handler and implicitly returned None; the failure value is now
            # deliberate and visible.
            return None
    else:
        return handle_blocking_request(json_data)