diff --git a/.gitignore b/.gitignore
index a25fca9..5fc60d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 .idea
+punkt/
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
@@ -161,4 +162,3 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-
diff --git a/punkt-download.py b/punkt-download.py
new file mode 100644
index 0000000..0b00ca8
--- /dev/null
+++ b/punkt-download.py
@@ -0,0 +1,9 @@
+import os
+from pathlib import Path
+
+import nltk
+
+script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
+target = script_dir / 'punkt'
+print('Downloading punkt to:', target)
+nltk.download('punkt', download_dir=target)
diff --git a/requirements.txt b/requirements.txt
index 2bb9933..83119b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,5 @@ starlette==0.37.2
 uvicorn==0.29.0
 transformers==4.40.1
 torch==2.3.0
+gunicorn==22.0.0
+nltk==3.8.1
\ No newline at end of file
diff --git a/server.py b/server.py
index 9243d14..ef7ae44 100644
--- a/server.py
+++ b/server.py
@@ -1,22 +1,27 @@
-"""
-NOTE: This API server is used only for demonstrating usage of AsyncEngine
-and simple performance benchmarks. It is not intended for production use.
-For production use, we recommend using our OpenAI compatible server.
-We are also not going to accept PRs modifying this file, please
-change `vllm/entrypoints/openai/api_server.py` instead. 
-"""
 import asyncio
+import logging
 import os
+import signal
+import sys
+from pathlib import Path
 
+import nltk
 from starlette.applications import Starlette
 from starlette.responses import JSONResponse
 from starlette.routing import Route
+from transformers import AutoTokenizer
 from transformers import pipeline
 
+logging.basicConfig()
+logger = logging.getLogger('MAIN')
+logger.setLevel(logging.INFO)
+
 
 async def generate(request):
-    payload = await request.body()
-    string = payload.decode("utf-8")
+    data = await request.json()
+    string = data.get('text')
+    if not string:
+        return JSONResponse({"error": "invalid request"}, status_code=400)
     response_q = asyncio.Queue()
     await request.app.model_queue.put((string, response_q))
     output = await response_q.get()
@@ -25,11 +30,29 @@ async def generate(request):
 
 async def server_loop(queue):
     model = os.getenv("MODEL", default="facebook/bart-large-cnn")
+    max_length = int(os.getenv("MAXLEN", default=200)) - 5  # add some buffer
     summarizer = pipeline("summarization", model=model, device=0)
+    tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
+
     while True:
         (string, response_q) = await queue.get()
-        out = summarizer(string, max_length=200, min_length=100, do_sample=False)
-        await response_q.put(out)
+
+        # Automatically adjust the string and the task's max length
+        sentences = nltk.sent_tokenize(string)
+        tokens = []
+        for sentence in sentences:
+            sentence_tokens = tokenizer.encode(sentence, truncation=False)
+            if len(tokens) + len(sentence_tokens) < max_length:
+                tokens.extend(sentence_tokens)
+            else:
+                break
+
+        try:
+            out = summarizer(tokenizer.decode(tokens), max_length=len(tokens), min_length=min(100, len(tokens)), do_sample=False)
+            await response_q.put(out)
+        except RuntimeError:
+            # Just kill ourselves and let gunicorn restart the worker. 
+            os.kill(os.getpid(), signal.SIGINT)
 
 
 app = Starlette(
@@ -41,6 +64,11 @@ app = Starlette(
 @app.on_event("startup")
 async def startup_event():
+    script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
+    if not (script_dir / 'punkt' / 'tokenizers' / 'punkt').is_dir():
+        logger.critical(f'Punkt not found at "{script_dir}/punkt". Please run "./punkt-download.py" first.')
+        sys.exit(1)
+    nltk.data.path.append(str(script_dir / 'punkt'))
     q = asyncio.Queue()
     app.model_queue = q
     asyncio.create_task(server_loop(q))
 
diff --git a/text-summarization-api.service b/text-summarization-api.service
index 1fa7d7b..9c4365c 100644
--- a/text-summarization-api.service
+++ b/text-summarization-api.service
@@ -7,7 +7,8 @@
 Type=simple
 User=flask
 WorkingDirectory=/srv/text-summarization-api
 Environment="MODEL=/storage/models/facebook/bart-large-cnn"
-ExecStart=/srv/text-summarization-api/venv/bin/uvicorn server:app --host 0.0.0.0
+Environment="MAXLEN=1000"
+ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000
 Restart=on-failure
 RestartSec=5s