diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py
index 2a7edc3..62dc2c5 100644
--- a/llm_server/cluster/backend.py
+++ b/llm_server/cluster/backend.py
@@ -6,6 +6,7 @@ from llm_server.cluster.stores import redis_running_models
 from llm_server.custom_redis import redis
 from llm_server.llm.generator import generator
 from llm_server.llm.info import get_info
+from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
 
@@ -33,11 +34,15 @@ def is_valid_model(model_name: str):
 def test_backend(backend_url: str, test_prompt: bool = False):
     backend_info = cluster_config.get_backend(backend_url)
     if test_prompt:
-        data = {
-            "prompt": "Test prompt",
+        handler = VLLMBackend(backend_url)
+        parameters, _ = handler.get_parameters({
             "stream": False,
             "temperature": 0,
             "max_new_tokens": 3,
+        })
+        data = {
+            'prompt': 'test prompt',
+            **parameters
         }
         try:
             success, response, err = generator(data, backend_url, timeout=10)
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index a1dd749..0738c1b 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -66,7 +66,6 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
                             json_obj = json.loads(json_str.decode())
                             new = json_obj['text'][0].split(prompt + generated_text)[1]
                             generated_text = generated_text + new
-                            print(new)
                         except IndexError:
                             # ????
                             continue