Cyberes 2024-04-28 14:37:43 -06:00
parent 05baf3e30e
commit 0a51ad30af
1 changed file with 35 additions and 25 deletions

@@ -22,7 +22,7 @@ logger = logging.getLogger("uvicorn.info")
 logger.setLevel(logging.DEBUG)
 ch = logging.StreamHandler()
 ch.setLevel(logging.DEBUG)
-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
 ch.setFormatter(formatter)
 logger.addHandler(ch)
@@ -38,48 +38,58 @@ async def generate(request):
     response_q = asyncio.Queue()
     await request.app.model_queue.put((string, response_q, request.client.host))
     output = await response_q.get()
-    return JSONResponse(output)
+    return JSONResponse({'summary_text': output})


 async def server_loop(queue):
+    boot_time = datetime.now()
     model = os.getenv("MODEL", default="facebook/bart-large-cnn")
     max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
     max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
     min_length = int(os.getenv("MIN_LEN", default=100))
     summarizer = pipeline("summarization", model=model, device=0)
     tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
-    logger.info('Worker started!')
+    logger.info(f'Worker started in {calculate_elapsed_time(boot_time)}')
+    del boot_time
     while True:
         (string, response_q, client_ip) = await queue.get()
         start_time = datetime.now()
-        # Automatically adjust the input string to the max length.
-        sentences = nltk.sent_tokenize(string)
-        tokens = []
-        for sentence in sentences:
-            sentence_tokens = tokenizer.encode(sentence, truncation=False)
-            if len(tokens) + len(sentence_tokens) < max_input_length:
-                tokens.extend(sentence_tokens)
-            else:
-                break
         out = ''
         err = False
-        for i in range(5):
-            try:
-                out = summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False)
-                await response_q.put(out)
-                break
-            except RuntimeError:
-                err = True
-                logger.error(f"Failed to generate summary for {string}: {traceback.format_exc()}")
-                time.sleep(1)
-                tokens = []
-        elapsed = calculate_elapsed_time(start_time)
-        logger.info(f"{client_ip} - {len(tokens)} tokens - {elapsed}")
+        try:
+            # Automatically adjust the input string to the max length.
+            sentences = nltk.sent_tokenize(string)
+            tokens = []
+            for sentence in sentences:
+                sentence_tokens = tokenizer.encode(sentence, truncation=False)
+                if len(tokens) + len(sentence_tokens) < max_input_length:
+                    tokens.extend(sentence_tokens)
+                else:
+                    break
+            for i in range(5):
+                try:
+                    out = list(summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False))[0]['summary_text']
+                    await response_q.put(out)
+                    err = False
+                    break
+                except:
+                    err = True
+                    logger.error(f'Error generating summary for "{string}": {traceback.format_exc()}')
+                    time.sleep(1)
+        except:
+            err = True
+            logger.error(f'Failed to generate summary for "{string}": {traceback.format_exc()}')
+        out_tokens = len(tokenizer.encode(out, truncation=False))
+        logger.info(f"{client_ip} - {len(tokens)} in tokens - {out_tokens} out tokens - {calculate_elapsed_time(start_time)} elapsed")
+        if err:
+            # Just kill ourselves and let gunicorn restart the worker.
+            logger.error('Worker rebooting due to exception.')
+            os.kill(os.getpid(), signal.SIGINT)
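
For clients, the visible change is the response shape: the endpoint now returns {'summary_text': ...} instead of the raw pipeline output. A hypothetical call sketch; the route, port, and how the request body is parsed are assumptions, since none of that appears in this diff:

import requests

# Assumed URL; the actual route and port depend on how the app is deployed.
resp = requests.post('http://localhost:8000/', data='Long article text to summarize...')
# After this commit the summary sits under the 'summary_text' key.
print(resp.json()['summary_text'])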
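
The sentence-level truncation kept in this commit can also be exercised on its own. A minimal sketch under the same defaults (facebook/bart-large-cnn, MAX_IN_LEN=200), assuming nltk's punkt tokenizer data is already downloaded:

import nltk
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
max_input_length = 200  # mirrors the MAX_IN_LEN default above

def truncate_to_token_budget(text: str) -> str:
    # Accumulate whole sentences until the next one would exceed the token
    # budget, the same way the loop in server_loop() does.
    tokens = []
    for sentence in nltk.sent_tokenize(text):
        sentence_tokens = tokenizer.encode(sentence, truncation=False)
        if len(tokens) + len(sentence_tokens) < max_input_length:
            tokens.extend(sentence_tokens)
        else:
            break
    return tokenizer.decode(tokens)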