From 0a51ad30af29f7670d61c20538972c810f4002e2 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sun, 28 Apr 2024 14:37:43 -0600
Subject: [PATCH] work

---
 server.py | 60 ++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 35 insertions(+), 25 deletions(-)

diff --git a/server.py b/server.py
index 75debb8..daf7c11 100644
--- a/server.py
+++ b/server.py
@@ -22,7 +22,7 @@ logger = logging.getLogger("uvicorn.info")
 logger.setLevel(logging.DEBUG)
 ch = logging.StreamHandler()
 ch.setLevel(logging.DEBUG)
-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
 ch.setFormatter(formatter)
 logger.addHandler(ch)
 
@@ -38,48 +38,58 @@ async def generate(request):
     response_q = asyncio.Queue()
     await request.app.model_queue.put((string, response_q, request.client.host))
     output = await response_q.get()
-    return JSONResponse(output)
+    return JSONResponse({'summary_text': output})
 
 
 async def server_loop(queue):
+    boot_time = datetime.now()
     model = os.getenv("MODEL", default="facebook/bart-large-cnn")
     max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
     max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
     min_length = int(os.getenv("MIN_LEN", default=100))
     summarizer = pipeline("summarization", model=model, device=0)
     tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
-    logger.info('Worker started!')
+    logger.info(f'Worker started in {calculate_elapsed_time(boot_time)}')
+    del boot_time
     while True:
         (string, response_q, client_ip) = await queue.get()
         start_time = datetime.now()
-
-        # Automatically adjust the input string to the max length.
-        sentences = nltk.sent_tokenize(string)
-        tokens = []
-        for sentence in sentences:
-            sentence_tokens = tokenizer.encode(sentence, truncation=False)
-            if len(tokens) + len(sentence_tokens) < max_input_length:
-                tokens.extend(sentence_tokens)
-            else:
-                break
-
+        out = ''
         err = False
-        for i in range(5):
-            try:
-                out = summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False)
-                await response_q.put(out)
-                break
-            except RuntimeError:
-                err = True
-                logger.error(f"Failed to generate summary for {string}: {traceback.format_exc()}")
-                time.sleep(1)
+        tokens = []
 
-        elapsed = calculate_elapsed_time(start_time)
-        logger.info(f"{client_ip} - {len(tokens)} tokens - {elapsed}")
+        try:
+            # Automatically adjust the input string to the max length.
+            sentences = nltk.sent_tokenize(string)
+            tokens = []
+            for sentence in sentences:
+                sentence_tokens = tokenizer.encode(sentence, truncation=False)
+                if len(tokens) + len(sentence_tokens) < max_input_length:
+                    tokens.extend(sentence_tokens)
+                else:
+                    break
+
+            for i in range(5):
+                try:
+                    out = list(summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False))[0]['summary_text']
+                    await response_q.put(out)
+                    err = False
+                    break
+                except:
+                    err = True
+                    logger.error(f'Error generating summary for "{string}": {traceback.format_exc()}')
+                    time.sleep(1)
+        except:
+            err = True
+            logger.error(f'Failed to generate summary for "{string}": {traceback.format_exc()}')
+
+        out_tokens = len(tokenizer.encode(out, truncation=False))
+        logger.info(f"{client_ip} - {len(tokens)} in tokens - {out_tokens} out tokens - {calculate_elapsed_time(start_time)} elapsed")
 
         if err:
             # Just kill ourselves and let gunicorn restart the worker.
+            logger.error('Worker rebooting due to exception.')
            os.kill(os.getpid(), signal.SIGINT)
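
With this change the endpoint wraps the pipeline output in a JSON object, so clients should read the summary from the 'summary_text' key rather than the raw pipeline result. A minimal client sketch follows, assuming the service is reachable at http://localhost:8000/ and accepts the article text as the raw request body; both the URL and the request format are assumptions for illustration, not shown in this patch.

    # Illustrative client only -- the URL, route, and request format are
    # assumptions, not taken from the patch.
    import requests

    SUMMARIZE_URL = "http://localhost:8000/"  # hypothetical; use the real host/port and route

    article_text = "Long article text to summarize ..."
    resp = requests.post(SUMMARIZE_URL, data=article_text.encode())
    resp.raise_for_status()

    # The worker now returns {'summary_text': ...} (see the JSONResponse change above).
    print(resp.json()["summary_text"])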