some cleanup/updates to caption stuff

parent 273ce23c7c
commit 056de840d0

caption_cog.py: 159 lines changed
caption_cog.py
@@ -36,7 +36,12 @@ from colorama import Fore, Style
 from plugins.caption_plugins import load_prompt_alteration_plugin
 from utils.patch_cog import patch_cog
-from data.gen_utils import image_generator, SUPPORTED_EXT
+from data.generators import image_path_generator, SUPPORTED_EXT
+
+try:
+    from moai.load_moai import prepare_moai
+except ImportError:
+    print("moai not found, skipping")
 
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
@@ -107,7 +112,7 @@ def save_params(args, gen_kwargs):
     with open(save_path, "w") as f:
         f.write(pretty_print)
 
-def create_bnb_config(args):
+def create_bnb_config():
     return BitsAndBytesConfig(
         bnb_4bit_compute_dtype="float32",
         bnb_4bit_quant_type="fp4",
@@ -121,58 +126,117 @@ def create_bnb_config(args):
         quant_method="bitsandbytes"
     )
 
+class MoaiManager:
+    def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.moai_model = None
+        self.moai_processor = None
+        self.seg_model = None
+        self.seg_processor = None
+        self.od_model = None
+        self.od_processor = None
+        self.sgg_model = None
+        self.ocr_model = None
+
+    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
+        moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
+            = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
+        self.moai_model = moai_model
+        self.moai_processor = moai_processor
+        self.seg_model = seg_model
+        self.seg_processor = seg_processor
+        self.od_model = od_model
+        self.od_processor = od_processor
+        self.sgg_model = sgg_model
+        self.ocr_model = ocr_model
+
+        return moai_model, moai_processor
+
+    def get_inputs(self, image: Image.Image, prompt: str):
+        moai_inputs = self.moai_model.demo_process(image=image,
+                                                   prompt=prompt,
+                                                   processor=self.moai_processor,
+                                                   seg_model=self.seg_model,
+                                                   seg_processor=self.seg_processor,
+                                                   od_model=self.od_model,
+                                                   od_processor=self.od_processor,
+                                                   sgg_model=self.sgg_model,
+                                                   ocr_model=self.ocr_model,
+                                                   device='cuda:0')
+        return moai_inputs
+
+    def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
+        with torch.inference_mode():
+            generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
+        answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
+        return answer
+
+class CogVLMManager:
+    def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.tokenizer = None
+        self.model = None
+
+    def load_model(self):
+        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            quantization_config=create_bnb_config()
+        )
+        return self.model, self.tokenizer
+
+    def get_inputs(self, prompt: str, history: List[Tuple[str, str]], images: List[Image.Image], starts_with: str):
+        return build_conversation_input_ids(self.tokenizer, query=prompt, history=history, images=images, starts_with=starts_with)
+
+    def get_gen_kwargs(self, args):
+        gen_kwargs = {
+            "max_length": args.max_length,
+            "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None,
+            "length_penalty": args.length_penalty,
+            "num_beams": args.num_beams,
+            "temperature": args.temp,
+            "top_k": args.top_k,
+            "top_p": args.top_p,
+            "repetition_penalty": args.repetition_penalty,
+            "no_repeat_ngram_size": args.no_repeat_ngram_size,
+            "min_new_tokens": args.min_new_tokens,
+            "max_new_tokens": args.max_new_tokens,
+        }
+        print(gen_kwargs)
+        if args.max_new_tokens is not None:
+            logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
+            del gen_kwargs["max_length"]
+
+        if not gen_kwargs["do_sample"]:
+            logging.info(f"** Using greedy sampling")
+            del gen_kwargs["top_k"]
+            del gen_kwargs["top_p"]
+            del gen_kwargs["temperature"]
+        else:
+            logging.info(f"** Sampling enabled")
+        return gen_kwargs
+
+def model_manager_factory(model_name: str):
+    if "moai" in model_name:
+        return MoaiManager(model_name)
+    else:
+        return CogVLMManager(model_name)
+
 def main(args):
     prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
+    model_manager = model_manager_factory(args.model)
 
-    bnb_config = create_bnb_config(args)
-
-    tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
-    model = AutoModelForCausalLM.from_pretrained(
-        'THUDM/cogvlm-chat-hf',
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
-        #revision=... # no one is actually doing this
-        #load_in_4bit=not args.disable_4bit,
-        quantization_config=bnb_config,
-    )
-
-    do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
-    if do_sample:
-        args.top_k = args.top_k or 50
-        args.top_p = args.top_p or 1.0
-        args.temp = args.temp or 1.0
+    model, tokenizer = model_manager.load_model()
 
     args.append = args.append or ""
     if len(args.append) > 0:
         args.append = " " + args.append.strip()
 
-    gen_kwargs = {
-        "max_length": args.max_length,
-        "do_sample": do_sample,
-        "length_penalty": args.length_penalty,
-        "num_beams": args.num_beams,
-        "temperature": args.temp,
-        "top_k": args.top_k,
-        "top_p": args.top_p,
-        "repetition_penalty": args.repetition_penalty,
-        "no_repeat_ngram_size": args.no_repeat_ngram_size,
-        "min_new_tokens": args.min_new_tokens,
-        "max_new_tokens": args.max_new_tokens,
-        "length_penalty": args.length_penalty,
-    }
-
-    if args.max_new_tokens is not None:
-        logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
-        del gen_kwargs["max_length"]
-
-    if not do_sample:
-        logging.info(f"** Using greedy sampling")
-        del gen_kwargs["top_k"]
-        del gen_kwargs["top_p"]
-        del gen_kwargs["temperature"]
-    else:
-        logging.info(f"** Sampling enabled")
+    gen_kwargs = model_manager.get_gen_kwargs(args)
 
     force_words_ids = None
     if args.force_words is not None:
@@ -195,7 +259,7 @@ def main(args):
 
     starts_with = args.starts_with.strip() if args.starts_with is not None else ""
 
-    for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)):
+    for i, image_path in enumerate(image_path_generator(args.image_dir, do_recurse=not args.no_recurse)):
         candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")
 
         if args.no_overwrite and os.path.exists(candidate_caption_path):
@@ -342,6 +406,7 @@ if __name__ == "__main__":
     argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
     argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
     argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
+    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
     args = argparser.parse_args()
 
     configure_logging(args)
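
Note on the refactor above: model loading and generation-kwargs handling now sit behind the manager classes instead of inline code in main(). A minimal sketch of the intended call flow, assuming the argparse flags defined in this file (the Namespace values below are illustrative only):

    from argparse import Namespace

    args = Namespace(model="THUDM/cogvlm-chat-hf", max_length=2048,
                     top_k=None, top_p=None, temp=None,
                     length_penalty=1.0, num_beams=1, repetition_penalty=1.0,
                     no_repeat_ngram_size=0, min_new_tokens=8, max_new_tokens=None)

    manager = model_manager_factory(args.model)  # CogVLMManager here; MoaiManager if "moai" is in the name
    model, tokenizer = manager.load_model()      # tokenizer + 4-bit quantized model via create_bnb_config()
    gen_kwargs = manager.get_gen_kwargs(args)    # all sampling flags None -> greedy; top_k/top_p/temperature dropped

Note that get_gen_kwargs exists only on CogVLMManager in this commit; MoaiManager callers pass sampling parameters to __call__ directly.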
data/gen_utils.py (deleted)
@@ -1,15 +0,0 @@
-import os
-from typing import Generator
-
-SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
-
-def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
-    if do_recurse:
-        for root, dirs, files in os.walk(image_dir):
-            for file in files:
-                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                    yield os.path.join(root, file)
-    else:
-        for file in os.listdir(image_dir):
-            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                yield os.path.join(image_dir, file)
data/generators.py (new file)
@@ -0,0 +1,96 @@
+"""
+Copyright [2022-2024] Victor C Hall
+
+Licensed under the GNU Affero General Public License;
+You may not use this code except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.gnu.org/licenses/agpl-3.0.en.html
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+from typing import Generator, Iterable
+from data.image_train_item import ImageTrainItem, ImageCaption
+from PIL import Image, ImageOps
+import tarfile
+import logging
+
+SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
+
+# subclasses Iterable rather than Generator: the Generator ABC also requires
+# send/throw, so a class defining only __iter__ could not be instantiated
+class BucketBatchedGenerator(Iterable[ImageTrainItem]):
+    """
+    yields items with the same aspect ratio in batches, for use with batching dataloaders
+    """
+    def __init__(self, batch_size: int=1, generator: Generator[ImageTrainItem, None, None]=None):
+        self.batch_size = batch_size
+        self.cache = {}
+        self.generator = generator
+
+    def __iter__(self):
+        for item in self.generator:
+            if item.target_wh:
+                aspect_bucket_key = item.target_wh
+                if aspect_bucket_key not in self.cache:
+                    self.cache[aspect_bucket_key] = []
+                self.cache[aspect_bucket_key].append(item)
+                if len(self.cache[aspect_bucket_key]) >= self.batch_size:
+                    for item in self.cache[aspect_bucket_key]:
+                        yield item
+                    self.cache[aspect_bucket_key] = []
+
+# def image_train_item_generator_from_tar_pairs(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
+#     for root, dirs, files in os.walk(image_dir):
+#         for file in files:
+#             if file.endswith(".tar"):
+#                 tar_path = os.path.join(root, file)
+#                 with tarfile.open(tar_path, "r") as tar:
+#                     for tarinfo in tar:
+#                         if tarinfo.isfile() and any(tarinfo.name.endswith(ext) for ext in SUPPORTED_EXT):
+#                             try:
+#                                 img = Image.open(tar.extractfile(tarinfo))
+#                                 txt = tar.extractfile(tarinfo.name.replace(os.path.splitext(tarinfo.name)[-1], ".txt"))
+#                                 caption = txt.read().decode("utf-8")
+#                                 img_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=256, use_weights=False)
+#                                 img = ImageOps.exif_transpose(img)
+#                                 iti = ImageTrainItem(img, img_caption)
+#                             except Exception as e:
+#                                 logging.error(f"Failed to open {tarinfo.name}: {e}")
+#                                 continue
+#                             yield iti
+
+def image_train_item_generator_from_files(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
+    for img_path in image_path_generator(image_dir, do_recurse):
+        try:
+            img = Image.open(img_path)
+            img = ImageOps.exif_transpose(img)
+        except Exception as e:
+            print(f"Failed to open {img_path}: {e}")
+            continue
+        # ImageCaption(main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool)
+        txt_cap_path = img_path.replace(os.path.splitext(img_path)[-1], ".txt")
+        caption = None
+        if os.path.exists(txt_cap_path):
+            with open(txt_cap_path, "r") as f:
+                caption = f.read()
+        if not caption or len(caption) < 1:
+            # no sidecar .txt, or empty caption: fall back to the filename stem before the first underscore
+            caption = os.path.basename(img_path)
+            caption = caption.split("_")[0]
+        image_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=128, use_weights=False)
+        iti = ImageTrainItem(img, image_caption)
+        yield iti
+
+def image_path_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
+    if do_recurse:
+        for root, dirs, files in os.walk(image_dir):
+            for file in files:
+                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                    yield os.path.join(root, file)
+    else:
+        for file in os.listdir(image_dir):
+            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                yield os.path.join(image_dir, file)
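
The BucketBatchedGenerator above buckets incoming ImageTrainItems by their target_wh aspect ratio and yields a bucket only once it holds a full batch, so a downstream batching dataloader always receives same-aspect groups. A minimal usage sketch (the directory path is hypothetical):

    from data.generators import BucketBatchedGenerator, image_train_item_generator_from_files

    items = image_train_item_generator_from_files("/data/my_images", do_recurse=True)
    for item in BucketBatchedGenerator(batch_size=4, generator=items):
        # items arrive in runs of 4 sharing the same target_wh bucket
        ...

One caveat visible in __iter__: buckets that never reach batch_size are left in the cache when the source generator is exhausted, so trailing partial batches are silently dropped.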
plugins/caption_plugins.py
@@ -22,7 +22,7 @@ class TestSub(TestBase):
     def __repr__(self) -> str:
         return f"TestSub: {self.a}, {self.b}"
 
-class PromptIdentityPlugin():
+class PromptIdentityBase():
     """
     Base class for prompt alteration plugins, useful for captioning, etc.
     """
@@ -68,7 +68,7 @@ class PromptIdentityPlugin():
         prompt = f"Hint: {hint}\n{prompt}"
         return prompt
 
-class HintFromFilename(PromptIdentityPlugin):
+class HintFromFilename(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="hint_from_filename",
                          description="Add a hint to the prompt using the filename of the image (without extension)",
@@ -81,7 +81,7 @@ class HintFromFilename(PromptIdentityPlugin):
         prompt = self._add_hint_to_prompt(filename, prompt)
         return prompt
 
-class RemoveUsingCSV(PromptIdentityPlugin):
+class RemoveUsingCSV(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="remove_using_csv",
                          description="Removes whole word matches of the csv passed in from the prompt",
@@ -111,7 +111,7 @@ class RemoveUsingCSV(PromptIdentityPlugin):
         prompt = self._filter_logic(prompt, [word])
         return prompt
 
-class HintFromLeafDirectory(PromptIdentityPlugin):
+class HintFromLeafDirectory(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_leaf_directory",
                          description="Adds a hint to the prompt using the leaf directory name (last folder in path)",
@@ -130,13 +130,13 @@ class MetadataProvider():
         self._datadict = {}
 
     def _from_metadata(self, args) -> dict:
-        image_path = args.get("image_path", "")
+        image_path = args.image_path
         prompt = args.get("prompt", "")
         metadata = self._get_metadata_dict(image_path)
         return f"metadata: {metadata}\n{prompt}"
 
     def _get_metadata_dict(self, metadata_path: str) -> dict:
-        if not self.loaded and not metadata_path in self.cache:
+        if not metadata_path in self._datadict:
             metadata_dirname = os.path.dirname(metadata_path)
             if not os.path.exists(metadata_path):
                 logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}")
@@ -145,12 +145,12 @@ class MetadataProvider():
                 metadata = json.load(f)
                 self._datadict[metadata_path] = metadata
 
-        return self.dict[metadata_path]
+        return self._datadict[metadata_path]
 
-class FromFolderMetadataJson(PromptIdentityPlugin):
+class FromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_folder_metadata",
-                         description="Looks for metadata.json in the folder of the images",
+                         description="Looks for metadata.json in the folder of the images and prefixes it to the prompt",
                          fn=self._from_metadata_json,
                          args=args)
         self.metadata_provider = MetadataProvider()
@@ -159,11 +159,12 @@ class FromFolderMetadataJson(PromptIdentityPlugin):
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
-        self.metadata_provider._get_metadata_dict(metadata_json_path)
-
-        return ""
+        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        metadata = json.dumps(metadata, indent=2)
+        prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
+        return prompt
 
-class TagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key = "tags_from_metadata_json",
@@ -185,7 +186,7 @@ class TagsFromFolderMetadataJson(PromptIdentityPlugin):
             return self._add_hint_to_prompt(f"tags: {tags}", prompt)
         return prompt
 
-class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
+class TitleAndTagsFromImageJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="title_and_tags_from_image_json",
                          description="Adds title and tags hint from metadata.json (in the same folder as the image) to the prompt",
@@ -218,7 +219,7 @@ class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key="title_and_tags_from_metadata_json",
@@ -254,7 +255,7 @@ class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromGlobalMetadataJson(PromptIdentityBase):
     """
    Adds title and tags hint from global metadata json given by '--metadatafilename'
    Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful?
@@ -316,8 +317,8 @@ def get_prompt_alteration_plugin_list() -> list:
 
             if isinstance(attribute, type) \
                 and attribute.__module__ == module.__name__ \
-                and is_subclass_of_subclass(attribute, PromptIdentityPlugin, recursion_depth=5) \
-                and attribute is not PromptIdentityPlugin:
+                and is_subclass_of_subclass(attribute, PromptIdentityBase, recursion_depth=5) \
+                and attribute is not PromptIdentityBase:
 
                 plugins.append(attribute)
         #print(f"done checking plugins_module_name: {plugins_module_name}")
@@ -337,4 +338,4 @@ def load_prompt_alteration_plugin(plugin_key: str, args) -> callable:
         raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins")
     else:
         logging.info(f"No plugin specified")
-        return PromptIdentityPlugin(args=args)
+        return PromptIdentityBase(args=args)
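
With the rename, third-party prompt-alteration plugins subclass PromptIdentityBase and register a key, description, and callback through super().__init__, as the classes above do. A hypothetical minimal plugin following that pattern (the base-class internals are not shown in this diff, so the callback contract is inferred from _from_metadata_json above, which reads args.prompt and returns the altered prompt):

    class UppercaseHint(PromptIdentityBase):
        def __init__(self, args: Namespace = None):
            super().__init__(key="uppercase_hint",
                             description="Uppercases the prompt before captioning",
                             fn=self._uppercase,
                             args=args)

        def _uppercase(self, args) -> str:
            # inferred contract: receive the full args, return the altered prompt
            return args.prompt.upper()

get_prompt_alteration_plugin_list() discovers such subclasses automatically, and the plugin would then be selected at runtime via the caption script's prompt_plugin argument (presumably --prompt_plugin uppercase_hint).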