add kosmos-2 caption script and update diffusers and transformers to latest

Victor Hall 2023-11-02 21:47:50 -04:00
parent 2f2dd4c1f2
commit 3150f7d299
3 changed files with 103 additions and 4 deletions

caption_kosmos2.py
@@ -0,0 +1,99 @@
"""
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import argparse
import time
from PIL import Image
from pynvml import *
from transformers import AutoProcessor, AutoModelForVision2Seq
GROUNDING = "<grounding>"
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
def get_gpu_memory_map():
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(handle)
nvmlShutdown()
return info.used/1024/1024
def remove_starting_string(a, b):
if b.startswith(a):
return b[len(a):] # Remove string A from the beginning of string B
else:
return b
def main(args):
for root, dirs, files in os.walk(args.data_root):
for file in files:
#get file extension
ext = os.path.splitext(file)[1]
if ext.lower() in SUPPORTED_EXT:
start_time = time.time()
full_file_path = os.path.join(root, file)
image = Image.open(full_file_path)
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
full_file_path = os.path.join(root, file)
image = Image.open(full_file_path)
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
generated_ids = model.generate(
pixel_values=inputs["pixel_values"],
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
image_embeds=None,
image_embeds_position_mask=inputs["image_embeds_position_mask"],
use_cache=True,
max_new_tokens=args.max_new_tokens,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
processed_text, entities = processor.post_process_generation(generated_text) # remove remaining special tokens to get just the caption and entities
if not args.keep_prompt:
processed_text = remove_starting_string(args.prompt, processed_text)
print(f"File: {image}, Generated caption: {processed_text}")
name = os.path.splitext(full_file_path)[0]
if not os.path.exists(f"{name}.txt") or args.over_write:
with open(f"{name}.txt", "w") as f:
f.write(processed_text)
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write):
with open(f"{name}.ent", "w") as entities_file:
entities_file.write(entities)
if __name__ == "__main__":
print("Kosmos-2 captioning script")
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption")
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
args = parser.parse_args()
main(args)
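For reference, the script walks a folder of images and writes a .txt caption next to each one; a typical invocation, once the updated dependencies below are installed, might look like the following (the folder path is only a placeholder, the flags are the argparse options defined above):

python caption_kosmos2.py --data_root /path/to/images --save_entities --over_write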


@@ -1,8 +1,8 @@
-diffusers[torch]>=0.18.0
+diffusers[torch]>=0.21.4
ninja
numpy
omegaconf==2.2.3
protobuf==3.20.3
pyre-extensions==0.0.29
pytorch-lightning==1.9.2
-transformers==4.29.2
+transformers==4.35.0


@@ -4,8 +4,8 @@ echo should be in venv here
cd .
python -m pip install --upgrade pip
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
-pip install -U transformers==4.29.2
-pip install -U diffusers[torch]==0.18.0
+pip install -U transformers==4.35.0
+pip install -U diffusers[torch]==0.21.4
pip install pynvml==11.4.1
pip install -U pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
pip install ftfy==6.1.1