Add random tag order during training
This commit is contained in:
parent
b080f33115
commit
2e69358d50
|
@ -78,12 +78,13 @@ data:
|
|||
resize: false
|
||||
flip_p: 0.5
|
||||
image_key: "image"
|
||||
copyright_rate: 0.9
|
||||
character_rate: 0.9
|
||||
general_rate: 0.9
|
||||
artist_rate: 0.9
|
||||
copyright_rate: 1.0
|
||||
character_rate: 1.0
|
||||
general_rate: 1.0
|
||||
artist_rate: 1.0
|
||||
normalize: true
|
||||
caption_shuffle: true
|
||||
random_order: true
|
||||
|
||||
lightning:
|
||||
modelcheckpoint:
|
||||
|
|
|
@ -36,7 +36,7 @@ def resize_image(image: Image, max_size=(768,768)):
|
|||
return res
|
||||
|
||||
class CaptionProcessor(object):
|
||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
|
||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
|
||||
self.copyright_rate = copyright_rate
|
||||
self.character_rate = character_rate
|
||||
self.general_rate = general_rate
|
||||
|
@ -46,6 +46,7 @@ class CaptionProcessor(object):
|
|||
self.transforms = transforms
|
||||
self.max_size = max_size
|
||||
self.resize = resize
|
||||
self.random_order = random_order
|
||||
|
||||
def clean(self, text: str):
|
||||
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
|
||||
|
@ -71,11 +72,19 @@ class CaptionProcessor(object):
|
|||
def __call__(self, sample):
|
||||
# preprocess caption
|
||||
caption_data = json.loads(sample['caption'])
|
||||
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
|
||||
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
|
||||
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
|
||||
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
|
||||
sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
|
||||
if not self.random_order:
|
||||
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
|
||||
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
|
||||
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
|
||||
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
|
||||
tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
|
||||
else:
|
||||
character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
|
||||
copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
|
||||
artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
|
||||
general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
|
||||
tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ')
|
||||
sample['caption'] = tag_str
|
||||
|
||||
# preprocess image
|
||||
image = sample['image']
|
||||
|
@ -150,7 +159,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
|||
transform_dict = {}
|
||||
transform_dict.update({self.image_key: image_transforms})
|
||||
|
||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)
|
||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize, random_order=self.random_order)
|
||||
|
||||
|
||||
tars = os.path.join(self.tar_base)
|
||||
|
|
Loading…
Reference in New Issue