This commit is contained in:
Cyberes 2024-04-28 14:26:20 -06:00
parent c300d0c240
commit 05baf3e30e
6 changed files with 53 additions and 16 deletions

1
.gitignore vendored
View File

@ -18,7 +18,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/

View File

@ -1,3 +1,3 @@
# text-summarization-api
_An API that uses AI to summarize text._
_A very simple HTTP API to summarize text using AI._

0
lib/__init__.py Normal file
View File

8
lib/time.py Normal file
View File

@ -0,0 +1,8 @@
from datetime import datetime
def calculate_elapsed_time(start_time: datetime):
diff = datetime.now() - start_time
hours, remainder = divmod(diff.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
return f'{int(hours):02}:{int(minutes):02}:{int(seconds):02}'

View File

@ -3,6 +3,9 @@ import logging
import os
import signal
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path
import nltk
@ -11,10 +14,20 @@ from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.utils import logging as tl
logging.basicConfig()
logger = logging.getLogger('MAIN')
logger.setLevel(logging.INFO)
from lib.time import calculate_elapsed_time
logger = logging.getLogger("uvicorn.info")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
tl.set_verbosity_error()
tl.captureWarnings(True)
async def generate(request):
@ -23,34 +36,49 @@ async def generate(request):
if not string:
return JSONResponse({"error": "invalid request"}, status_code=400)
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
await request.app.model_queue.put((string, response_q, request.client.host))
output = await response_q.get()
return JSONResponse(output)
async def server_loop(queue):
model = os.getenv("MODEL", default="facebook/bart-large-cnn")
max_length = int(os.getenv("MAXLEN", default=200)) - 5 # add some buffer
max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
min_length = int(os.getenv("MIN_LEN", default=100))
summarizer = pipeline("summarization", model=model, device=0)
tokenizer = AutoTokenizer.from_pretrained(model) # load tokenizer from the model
logger.info('Worker started!')
while True:
(string, response_q) = await queue.get()
(string, response_q, client_ip) = await queue.get()
start_time = datetime.now()
# Automatically adjust the string and the task's max length
# Automatically adjust the input string to the max length.
sentences = nltk.sent_tokenize(string)
tokens = []
for sentence in sentences:
sentence_tokens = tokenizer.encode(sentence, truncation=False)
if len(tokens) + len(sentence_tokens) < max_length:
if len(tokens) + len(sentence_tokens) < max_input_length:
tokens.extend(sentence_tokens)
else:
break
try:
out = summarizer(tokenizer.decode(tokens), max_length=len(tokens), min_length=100, do_sample=False)
await response_q.put(out)
except RuntimeError:
err = False
for i in range(5):
try:
out = summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False)
await response_q.put(out)
break
except RuntimeError:
err = True
logger.error(f"Failed to generate summary for {string}: {traceback.format_exc()}")
time.sleep(1)
elapsed = calculate_elapsed_time(start_time)
logger.info(f"{client_ip} - {len(tokens)} tokens - {elapsed}")
if err:
# Just kill ourselves and let gunicorn restart the worker.
os.kill(os.getpid(), signal.SIGINT)

View File

@ -7,8 +7,10 @@ Type=simple
User=flask
WorkingDirectory=/srv/text-summarization-api
Environment="MODEL=/storage/models/facebook/bart-large-cnn"
Environment="MAXLEN=1000"
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000 --access-logfile - --error-logfile -
Environment="MAX_IN_LEN=1000"
Environment="MAX_NEW_LEN=1000"
Environment="MIN_LEN=100"
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000 --access-logfile - --error-logfile - --log-level info
Restart=on-failure
RestartSec=5s
SyslogIdentifier=TextSummarizationAPI