"""Starlette-based text summarization service backed by a transformers pipeline."""
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import nltk
|
|
from starlette.applications import Starlette
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Route
|
|
from transformers import AutoTokenizer
|
|
from transformers import pipeline
|
|
from transformers.utils import logging as tl
|
|
|
|
from lib.time import calculate_elapsed_time
|
|
|
|
# Log through uvicorn's logger so our output is unified with the server's;
# attach a dedicated DEBUG-level stream handler with a timestamped format.
logger = logging.getLogger("uvicorn.info")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

# Quiet transformers' own logging down to errors only, and route its
# warnings through the logging system instead of the warnings module.
tl.set_verbosity_error()
tl.captureWarnings(True)
|
|
|
|
|
|
async def generate(request):
    """Handle POST / — enqueue a summarization request and await its result.

    Expects a JSON body with a non-empty "text" field. Returns HTTP 400 for
    a malformed body or a missing/empty "text"; otherwise blocks until the
    worker loop publishes the summary on a per-request response queue.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError are ValueError subclasses;
        # previously a non-JSON body escaped as an unhandled 500.
        return JSONResponse({"error": "invalid request"}, status_code=400)

    text = data.get('text')
    if not text:
        return JSONResponse({"error": "invalid request"}, status_code=400)

    # Per-request queue: the worker loop puts exactly one result on it.
    response_q = asyncio.Queue()
    await request.app.model_queue.put((text, response_q, request.client.host))
    output = await response_q.get()
    return JSONResponse({'summary_text': output})
|
|
|
|
|
|
async def server_loop(queue):
    """Worker coroutine: consume (text, response_q, client_ip) tuples from
    *queue*, summarize the text, and put the summary string on *response_q*.

    The model and length limits are configured through the MODEL, MAX_IN_LEN,
    MAX_NEW_LEN and MIN_LEN environment variables. On an unrecoverable error
    the worker signals its own process with SIGINT so the process manager
    (gunicorn) restarts it.
    """
    boot_time = datetime.now()
    model = os.getenv("MODEL", default="facebook/bart-large-cnn")
    max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
    max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
    min_length = int(os.getenv("MIN_LEN", default=100))
    # device=0 pins the pipeline to the first GPU.
    summarizer = pipeline("summarization", model=model, device=0)
    tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
    logger.info(f'Worker started in {calculate_elapsed_time(boot_time)}')

    while True:
        (string, response_q, client_ip) = await queue.get()
        start_time = datetime.now()
        out = ''
        err = False
        tokens = []

        try:
            # Automatically adjust the input string to the max length:
            # accumulate whole sentences until the next one would overflow.
            sentences = nltk.sent_tokenize(string)
            for sentence in sentences:
                sentence_tokens = tokenizer.encode(sentence, truncation=False)
                if len(tokens) + len(sentence_tokens) < max_input_length:
                    tokens.extend(sentence_tokens)
                else:
                    break

            # Retry transient generation failures a few times before giving up.
            for _ in range(5):
                try:
                    out = list(summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False))[0]['summary_text']
                    await response_q.put(out)
                    err = False
                    break
                except Exception:
                    err = True
                    logger.error(f'Error generating summary for "{string}": {traceback.format_exc()}')
                    # FIX: time.sleep(1) here blocked the entire event loop
                    # (and every pending request) during retry back-off.
                    await asyncio.sleep(1)
        except Exception:
            # FIX: was a bare `except:`, which also swallowed CancelledError
            # and KeyboardInterrupt.
            err = True
            logger.error(f'Failed to generate summary for "{string}": {traceback.format_exc()}')

        out_tokens = len(tokenizer.encode(out, truncation=False))
        logger.info(f"{client_ip} - {len(tokens)} in tokens - {out_tokens} out tokens - {calculate_elapsed_time(start_time)} elapsed")

        if err:
            # NOTE(review): in this path the waiting request never receives a
            # response; the client sees the connection drop on worker restart.
            # Just kill ourselves and let gunicorn restart the worker.
            logger.error('Worker rebooting due to exception.')
            os.kill(os.getpid(), signal.SIGINT)
|
|
|
|
|
|
# Single POST endpoint at the root that accepts summarization requests.
app = Starlette(routes=[Route("/", generate, methods=["POST"])])
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Verify the NLTK punkt data is present, then launch the worker loop."""
    # pathlib idiom for the old os.path.dirname(os.path.realpath(__file__));
    # .resolve() follows symlinks exactly like realpath did.
    script_dir = Path(__file__).resolve().parent
    if not (script_dir / 'punkt' / 'tokenizers' / 'punkt').is_dir():
        logger.critical(f'Punkt not found at "{script_dir}/punkt". Please run "./punkt-download.py" first.')
        sys.exit(1)

    # Single shared request queue, consumed by the background worker task.
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))
|