# text-summarization-api/server.py


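"""Text summarization API.

A small Starlette service that wraps a Hugging Face summarization pipeline.
Incoming requests are queued and handled one at a time by a single background
worker, so the model is loaded only once per process.

Configuration (environment variables):
    MODEL        model name or path (default: facebook/bart-large-cnn)
    MAX_IN_LEN   maximum number of input tokens (default: 200)
    MAX_NEW_LEN  maximum number of generated tokens (default: 200)
    MIN_LEN      minimum summary length (default: 100)

Usage sketch (assuming the app is served on port 8000, e.g. behind gunicorn
with uvicorn workers, as the worker-restart comment below suggests):

    curl -X POST http://localhost:8000/ -d '{"text": "Article to summarize ..."}'
"""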
import asyncio
import logging
import os
import signal
import sys
import traceback
from datetime import datetime
from pathlib import Path
import nltk
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.utils import logging as tl
from lib.time import calculate_elapsed_time
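
# Emit our own log lines through a dedicated stream handler and keep the
# transformers library quiet except for errors.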
logger = logging.getLogger("uvicorn.info")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
tl.set_verbosity_error()
tl.captureWarnings(True)
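

# POST / handler: expects a JSON body like {"text": "..."}, pushes the text onto
# the shared model queue together with a per-request response queue, then waits
# for the worker to publish the summary.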
async def generate(request):
    data = await request.json()
    string = data.get('text')
    if not string:
        return JSONResponse({"error": "invalid request"}, status_code=400)
    response_q = asyncio.Queue()
    await request.app.model_queue.put((string, response_q, request.client.host))
    output = await response_q.get()
    return JSONResponse({'summary_text': output})
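

# Background worker: loads the summarization pipeline and tokenizer once, then
# serves queued requests one at a time, trimming each input to whole sentences
# that fit within MAX_IN_LEN tokens before summarizing.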
async def server_loop(queue):
    boot_time = datetime.now()
    model = os.getenv("MODEL", default="facebook/bart-large-cnn")
    max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
    max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
    min_length = int(os.getenv("MIN_LEN", default=100))
    summarizer = pipeline("summarization", model=model, device=0)
    tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
    logger.info(f'Worker started in {calculate_elapsed_time(boot_time)}')
    del boot_time
    while True:
        (string, response_q, client_ip) = await queue.get()
        start_time = datetime.now()
        out = ''
        err = False
        tokens = []
        try:
            # Automatically adjust the input string to the max length.
            sentences = nltk.sent_tokenize(string)
            for sentence in sentences:
                sentence_tokens = tokenizer.encode(sentence, truncation=False)
                if len(tokens) + len(sentence_tokens) < max_input_length:
                    tokens.extend(sentence_tokens)
                else:
                    break
            # Retry the pipeline a few times before giving up on this request.
            for _ in range(5):
                try:
                    out = list(summarizer(
                        tokenizer.decode(tokens),
                        max_new_tokens=max_new_tokens,
                        min_length=min_length,
                        do_sample=False,
                    ))[0]['summary_text']
                    await response_q.put(out)
                    err = False
                    break
                except Exception:
                    err = True
                    logger.error(f'Error generating summary for "{string}": {traceback.format_exc()}')
                    # Use an async sleep so retries do not block the event loop.
                    await asyncio.sleep(1)
        except Exception:
            err = True
            logger.error(f'Failed to generate summary for "{string}": {traceback.format_exc()}')
        out_tokens = len(tokenizer.encode(out, truncation=False))
        logger.info(f"{client_ip} - {len(tokens)} in tokens - {out_tokens} out tokens - {calculate_elapsed_time(start_time)} elapsed")
        if err:
            # Just kill ourselves and let gunicorn restart the worker.
            logger.error('Worker rebooting due to exception.')
            os.kill(os.getpid(), signal.SIGINT)
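

# The app exposes a single POST route; every request goes through generate().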
app = Starlette(
    routes=[
        Route("/", generate, methods=["POST"]),
    ],
)
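
# On startup: verify the bundled punkt data is present, create the shared request
# queue, and launch the model worker on the event loop.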
@app.on_event("startup")
async def startup_event():
script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
if not (script_dir / 'punkt' / 'tokenizers' / 'punkt').is_dir():
logger.critical(f'Punkt not found at "{script_dir}/punkt". Please run "./punkt-download.py" first.')
sys.exit(1)
q = asyncio.Queue()
app.model_queue = q
asyncio.create_task(server_loop(q))