work
This commit is contained in:
parent
c300d0c240
commit
05baf3e30e
|
@ -18,7 +18,6 @@ dist/
|
|||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
# text-summarization-api
|
||||
|
||||
_An API that uses AI to summarize text._
|
||||
_A very simple HTTP API to summarize text using AI._
|
|
@ -0,0 +1,8 @@
|
|||
from datetime import datetime
|
||||
|
||||
|
||||
def calculate_elapsed_time(start_time: datetime, end_time=None) -> str:
    """Return the length of a time interval formatted as ``HH:MM:SS``.

    Args:
        start_time: Moment the measured interval began (naive ``datetime``,
            as produced by ``datetime.now()``).
        end_time: Optional end of the interval. Defaults to ``None``, in
            which case ``datetime.now()`` is used — so the original
            one-argument call style (``calculate_elapsed_time(start)``)
            keeps working unchanged.

    Returns:
        A zero-padded ``HH:MM:SS`` string. The hours field grows past two
        digits for intervals of 100 hours or more.
    """
    if end_time is None:
        end_time = datetime.now()
    diff = end_time - start_time
    # total_seconds() flattens days into seconds so multi-day intervals
    # show up in the hours field rather than being dropped.
    hours, remainder = divmod(diff.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    return f'{int(hours):02}:{int(minutes):02}:{int(seconds):02}'
|
52
server.py
52
server.py
|
@ -3,6 +3,9 @@ import logging
|
|||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nltk
|
||||
|
@ -11,10 +14,20 @@ from starlette.responses import JSONResponse
|
|||
from starlette.routing import Route
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import pipeline
|
||||
from transformers.utils import logging as tl
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger('MAIN')
|
||||
logger.setLevel(logging.INFO)
|
||||
from lib.time import calculate_elapsed_time
|
||||
|
||||
logger = logging.getLogger("uvicorn.info")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
tl.set_verbosity_error()
|
||||
tl.captureWarnings(True)
|
||||
|
||||
|
||||
async def generate(request):
|
||||
|
@ -23,34 +36,49 @@ async def generate(request):
|
|||
if not string:
|
||||
return JSONResponse({"error": "invalid request"}, status_code=400)
|
||||
response_q = asyncio.Queue()
|
||||
await request.app.model_queue.put((string, response_q))
|
||||
await request.app.model_queue.put((string, response_q, request.client.host))
|
||||
output = await response_q.get()
|
||||
return JSONResponse(output)
|
||||
|
||||
|
||||
async def server_loop(queue):
|
||||
model = os.getenv("MODEL", default="facebook/bart-large-cnn")
|
||||
max_length = int(os.getenv("MAXLEN", default=200)) - 5 # add some buffer
|
||||
max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
|
||||
max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
|
||||
min_length = int(os.getenv("MIN_LEN", default=100))
|
||||
summarizer = pipeline("summarization", model=model, device=0)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model) # load tokenizer from the model
|
||||
logger.info('Worker started!')
|
||||
|
||||
while True:
|
||||
(string, response_q) = await queue.get()
|
||||
(string, response_q, client_ip) = await queue.get()
|
||||
start_time = datetime.now()
|
||||
|
||||
# Automatically adjust the string and the task's max length
|
||||
# Automatically adjust the input string to the max length.
|
||||
sentences = nltk.sent_tokenize(string)
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
sentence_tokens = tokenizer.encode(sentence, truncation=False)
|
||||
if len(tokens) + len(sentence_tokens) < max_length:
|
||||
if len(tokens) + len(sentence_tokens) < max_input_length:
|
||||
tokens.extend(sentence_tokens)
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
out = summarizer(tokenizer.decode(tokens), max_length=len(tokens), min_length=100, do_sample=False)
|
||||
await response_q.put(out)
|
||||
except RuntimeError:
|
||||
err = False
|
||||
for i in range(5):
|
||||
try:
|
||||
out = summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False)
|
||||
await response_q.put(out)
|
||||
break
|
||||
except RuntimeError:
|
||||
err = True
|
||||
logger.error(f"Failed to generate summary for {string}: {traceback.format_exc()}")
|
||||
time.sleep(1)
|
||||
|
||||
elapsed = calculate_elapsed_time(start_time)
|
||||
logger.info(f"{client_ip} - {len(tokens)} tokens - {elapsed}")
|
||||
|
||||
if err:
|
||||
# Just kill ourselves and let gunicorn restart the worker.
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
|
|
|
@ -7,8 +7,10 @@ Type=simple
|
|||
User=flask
|
||||
WorkingDirectory=/srv/text-summarization-api
|
||||
Environment="MODEL=/storage/models/facebook/bart-large-cnn"
|
||||
Environment="MAXLEN=1000"
|
||||
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000 --access-logfile - --error-logfile -
|
||||
Environment="MAX_IN_LEN=1000"
|
||||
Environment="MAX_NEW_LEN=1000"
|
||||
Environment="MIN_LEN=100"
|
||||
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000 --access-logfile - --error-logfile - --log-level info
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
SyslogIdentifier=TextSummarizationAPI
|
||||
|
|
Loading…
Reference in New Issue