diff --git a/caption_cog.py b/caption_cog.py index c3867a5..b5feb9b 100644 --- a/caption_cog.py +++ b/caption_cog.py @@ -23,6 +23,7 @@ from typing import Generator import torch from PIL import Image +import PIL.ImageOps as ImageOps from pynvml import * from transformers import AutoModelForCausalLM, LlamaTokenizer @@ -53,19 +54,29 @@ def main(args): load_in_4bit=not args.disable_4bit, ) - do_sample = args.num_beams > 1 + do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None + if do_sample: + args.top_k = args.top_k or 50 + args.top_p = args.top_p or 1.0 + args.temp = args.temp or 1.0 + + args.append = args.append or "" + if len(args.append) > 0 and not args.append.startswith(" "): + args.append = " " + args.append gen_kwargs = { - "max_length": args.max_length, - "do_sample": do_sample, - "num_beams": args.num_beams, - "temperature": args.temp, - "top_k": args.top_k, - "top_p": args.top_p, - "repetition_penalty": args.repetition_penalty, - "no_repeat_ngram_size": args.no_repeat_ngram_size, + "max_length": args.max_length, + "do_sample": do_sample, + "length_penalty": args.length_penalty, + "num_beams": args.num_beams, + "temperature": args.temp, + "top_k": args.top_k, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "no_repeat_ngram_size": args.no_repeat_ngram_size, "min_new_tokens": args.min_new_tokens, "max_new_tokens": args.max_new_tokens, + "length_penalty": args.length_penalty, } if args.max_new_tokens is not None: @@ -73,9 +84,12 @@ def main(args): del gen_kwargs["max_length"] if not do_sample: - print(f"** num_beams set to 1, sampling is disabled") + print(f"** Using greedy sampling") del gen_kwargs["top_k"] del gen_kwargs["top_p"] + del gen_kwargs["temperature"] + else: + print(f"** Sampling enabled") force_words_ids = None if args.force_words is not None: @@ -103,6 +117,14 @@ def main(args): start_time = time.time() image = Image.open(image_path) + + try: + image = image.convert('RGB') + image = ImageOps.exif_transpose(image) + except Exception as e: + print(f"Non-fatal error processing {image_path}: {e}") + continue + inputs = model.build_conversation_input_ids(tokenizer, query=args.prompt, history=[], images=[image]) # chat mode inputs = { 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'), @@ -115,6 +137,7 @@ def main(args): outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) outputs_without_prompt = outputs[:, inputs['input_ids'].shape[1]:] caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True) + caption += args.append with open(candidate_caption_path, "w") as f: f.write(caption) @@ -137,40 +160,64 @@ EXAMPLES = """ex. Basic example: python caption_cog.py --image_dir /mnt/mydata/kyrie/ --prompt 'Describe this image in detail, including the subject matter and medium of the artwork.' +Use probabilistic sampling by using any of top_k, top_p, or temp: + python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"What is this?\" --top_p 0.9 + Use beam search and probabilistic sampling: - python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt 'Write a description.' --max_new_tokens 75 --num_beams 4 --temp 0.9 --top_k 3 --top_p 0.9 --repetition_penalty 1.0 --no_repeat_ngram_size 0 --min_new_tokens 5\n + python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"Write a description.\" --max_new_tokens 75 --num_beams 4 --temp 0.9 --top_k 3 --top_p 0.9 --repetition_penalty 1.0 --no_repeat_ngram_size 0 --min_new_tokens 5 Force "cat" and "dog" and disallow the word "depicts": python caption_cog.py --image_dir /mnt/lcl/nvme/mldata/test --num_beams 3 --force_words "cat,dog" --bad_words "depicts" +Use a lot of beams and try to control the length with length_penalty: + python caption_cog.py --image_dir /mnt/lcl/nvme/mldata/test --num_beams 8 --length_penalty 0.8 --prompt "Write a single sentence description." + Notes: - numbeams > 1 enables probabilistic sampling, which is required for the temperature, top_k, top_p parameters to function. More beams is more opinions on the next token, but slower and more VRAM intensive as it is done in batch mode. - Increasing num_beams has a substantial impact on VRAM and speed. Ex beams =1 ~13.3gb, beams = 4 ~ 23.7GB - Speed is linearly proportional to num_beams, so 4 beams is 4x slower than 1 beam. - Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. + 1. Setting top_k, top_p, or temp enables probabilistic sampling (aka "do_sample"), otherwise greedy sampling is used. + a. num_beams 1 and do_sample false uses "greedy decoding" + b. num_beams 1 and do_sample true uses "multinomial sampling" + c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling" + d. num_beams > 1 and do_sample false uses "beam-search decoding" + 2. Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. Default is max_length 2048 if nothing set. + Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead. + +Find more info on the Huggingface Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation +Parameters definitions and use map directly to their API. """ -DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL} **\n" +DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL} **\n Use --help for usage." if __name__ == "__main__": - argparser = argparse.ArgumentParser(description=DESCRIPTION, epilog=EXAMPLES) + argparser = argparse.ArgumentParser() argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.") - argparser.add_argument("--temp", type=float, default=1.0, help="Temperature for sampling") - argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for sampling, see notes.") - argparser.add_argument("--top_k", type=int, default=0, help="Top-k, filter k highest probability tokens before sampling") - argparser.add_argument("--top_p", type=float, default=1.0, help="Top-p, selects from top tokens with cumulative probability >= p") + argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling") + argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)") + argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling") + argparser.add_argument("--top_p", type=float, default=None, help="Top-p, for sampling, selects from top tokens with cumulative probability >= p") argparser.add_argument("--repetition_penalty", type=float, default=1.0, help="Repetition penalty") argparser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="No repetition n-gram size") argparser.add_argument("--min_new_tokens", type=int, default=5, help="Minimum number of tokens in returned caption.") argparser.add_argument("--max_new_tokens", type=int, default=None, help="Maximum number of tokens in returned caption.") argparser.add_argument("--max_length", type=int, default=2048, help="Alternate to max_new_tokens, limits context.") - argparser.add_argument("--prompt", type=str, default="Describe this image.", help="Prompt that will guide captioning") + argparser.add_argument("--length_penalty", type=float, default=1.0, help="Length penalty, lower values encourage shorter captions.") + argparser.add_argument("--prompt", type=str, default="Write a description.", help="Prompt that will guide captioning") argparser.add_argument("--image_dir", type=str, default=None, help="Path to folder of images to caption") argparser.add_argument("--no_overwrite", action="store_true", help="Skips captioning images that already have a caption file.") argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.") argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.") + argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'") args = argparser.parse_args() + print(DESCRIPTION) + print(EXAMPLES) + + if args.top_k is not None or args.top_p is not None or args.temp is not None: + print(f"** Sampling enabled.") + args.sampling = True + args.top_k = args.top_k or 50 + args.top_p = args.top_p or 1.0 + args.temp = args.temp or 1.0 + print(DESCRIPTION) print(EXAMPLES) if args.image_dir is None: