diff --git a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml index ff3b871..3414b71 100644 --- a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml +++ b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml @@ -78,12 +78,13 @@ data: resize: false flip_p: 0.5 image_key: "image" - copyright_rate: 0.9 - character_rate: 0.9 - general_rate: 0.9 - artist_rate: 0.9 + copyright_rate: 1.0 + character_rate: 1.0 + general_rate: 1.0 + artist_rate: 1.0 normalize: true caption_shuffle: true + random_order: true lightning: modelcheckpoint: diff --git a/ldm/data/localdanbooru.py b/ldm/data/localdanbooru.py index 7624646..31b0c3a 100644 --- a/ldm/data/localdanbooru.py +++ b/ldm/data/localdanbooru.py @@ -36,7 +36,7 @@ def resize_image(image: Image, max_size=(768,768)): return res class CaptionProcessor(object): - def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize): + def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order): self.copyright_rate = copyright_rate self.character_rate = character_rate self.general_rate = general_rate @@ -46,6 +46,7 @@ class CaptionProcessor(object): self.transforms = transforms self.max_size = max_size self.resize = resize + self.random_order = random_order def clean(self, text: str): text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip() @@ -71,11 +72,19 @@ class CaptionProcessor(object): def __call__(self, sample): # preprocess caption caption_data = json.loads(sample['caption']) - character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True) - copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True) - artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True) - general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False) - sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',') + if not self.random_order: + character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True) + copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True) + artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True) + general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False) + tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',') + else: + character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False) + copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False) + artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False) + general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False) + tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ') + sample['caption'] = tag_str # preprocess image image = sample['image'] @@ -150,7 +159,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule): transform_dict = {} transform_dict.update({self.image_key: image_transforms}) - postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize) + postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize, random_order=self.random_order) tars = os.path.join(self.tar_base)