diff --git a/caption_kosmos2.py b/caption_kosmos2.py new file mode 100644 index 0000000..c782a92 --- /dev/null +++ b/caption_kosmos2.py @@ -0,0 +1,99 @@ +""" +Copyright [2022-2023] Victor C Hall + +Licensed under the GNU Affero General Public License; +You may not use this code except in compliance with the License. +You may obtain a copy of the License at + + https://www.gnu.org/licenses/agpl-3.0.en.html + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import io +import argparse +import time + +from PIL import Image +from pynvml import * +from transformers import AutoProcessor, AutoModelForVision2Seq + + +GROUNDING = "" + +SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] + +def get_gpu_memory_map(): + nvmlInit() + handle = nvmlDeviceGetHandleByIndex(0) + info = nvmlDeviceGetMemoryInfo(handle) + nvmlShutdown() + return info.used/1024/1024 + +def remove_starting_string(a, b): + if b.startswith(a): + return b[len(a):] # Remove string A from the beginning of string B + else: + return b + +def main(args): + for root, dirs, files in os.walk(args.data_root): + for file in files: + #get file extension + ext = os.path.splitext(file)[1] + if ext.lower() in SUPPORTED_EXT: + start_time = time.time() + + full_file_path = os.path.join(root, file) + image = Image.open(full_file_path) + model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224") + processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") + + full_file_path = os.path.join(root, file) + image = Image.open(full_file_path) + + inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt") + + generated_ids = model.generate( + pixel_values=inputs["pixel_values"], + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + image_embeds=None, + image_embeds_position_mask=inputs["image_embeds_position_mask"], + use_cache=True, + max_new_tokens=args.max_new_tokens, + ) + + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + processed_text, entities = processor.post_process_generation(generated_text) # remove remaining special tokens to get just the caption and entities + + if not args.keep_prompt: + processed_text = remove_starting_string(args.prompt, processed_text) + + print(f"File: {image}, Generated caption: {processed_text}") + + name = os.path.splitext(full_file_path)[0] + if not os.path.exists(f"{name}.txt") or args.over_write: + with open(f"{name}.txt", "w") as f: + f.write(processed_text) + + if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write): + with open(f"{name}.ent", "w") as entities_file: + entities_file.write(entities) + +if __name__ == "__main__": + print("Kosmos-2 captioning script") + parser = argparse.ArgumentParser() + parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption") + parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption") + parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved") + parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate") + parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file") + parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/docker/requirements-build.txt b/docker/requirements-build.txt index 64feddc..8670b40 100644 --- a/docker/requirements-build.txt +++ b/docker/requirements-build.txt @@ -1,8 +1,8 @@ -diffusers[torch]>=0.18.0 +diffusers[torch]>=0.21.4 ninja numpy omegaconf==2.2.3 protobuf==3.20.3 pyre-extensions==0.0.29 pytorch-lightning==1.9.2 -transformers==4.29.2 +transformers==4.35.0 diff --git a/windows_setup.cmd b/windows_setup.cmd index e81d9ff..58237bc 100644 --- a/windows_setup.cmd +++ b/windows_setup.cmd @@ -4,8 +4,8 @@ echo should be in venv here cd . python -m pip install --upgrade pip pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118" -pip install -U transformers==4.29.2 -pip install -U diffusers[torch]==0.18.0 +pip install -U transformers==4.35.0 +pip install -U diffusers[torch]==0.21.4 pip install pynvml==11.4.1 pip install -U pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl pip install ftfy==6.1.1