Add random tag order during training

This commit is contained in:
harubaru 2022-09-21 04:57:17 -07:00
parent b080f33115
commit 2e69358d50
2 changed files with 21 additions and 11 deletions

View File

@ -78,12 +78,13 @@ data:
resize: false resize: false
flip_p: 0.5 flip_p: 0.5
image_key: "image" image_key: "image"
copyright_rate: 0.9 copyright_rate: 1.0
character_rate: 0.9 character_rate: 1.0
general_rate: 0.9 general_rate: 1.0
artist_rate: 0.9 artist_rate: 1.0
normalize: true normalize: true
caption_shuffle: true caption_shuffle: true
random_order: true
lightning: lightning:
modelcheckpoint: modelcheckpoint:

View File

@ -36,7 +36,7 @@ def resize_image(image: Image, max_size=(768,768)):
return res return res
class CaptionProcessor(object): class CaptionProcessor(object):
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize): def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
self.copyright_rate = copyright_rate self.copyright_rate = copyright_rate
self.character_rate = character_rate self.character_rate = character_rate
self.general_rate = general_rate self.general_rate = general_rate
@ -46,6 +46,7 @@ class CaptionProcessor(object):
self.transforms = transforms self.transforms = transforms
self.max_size = max_size self.max_size = max_size
self.resize = resize self.resize = resize
self.random_order = random_order
def clean(self, text: str): def clean(self, text: str):
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip() text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
@ -71,11 +72,19 @@ class CaptionProcessor(object):
def __call__(self, sample): def __call__(self, sample):
# preprocess caption # preprocess caption
caption_data = json.loads(sample['caption']) caption_data = json.loads(sample['caption'])
if not self.random_order:
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True) character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True) copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True) artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False) general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',') tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
else:
character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ')
sample['caption'] = tag_str
# preprocess image # preprocess image
image = sample['image'] image = sample['image']
@ -150,7 +159,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
transform_dict = {} transform_dict = {}
transform_dict.update({self.image_key: image_transforms}) transform_dict.update({self.image_key: image_transforms})
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize) postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize, random_order=self.random_order)
tars = os.path.join(self.tar_base) tars = os.path.join(self.tar_base)