Add random tag order during training
This commit is contained in:
parent
b080f33115
commit
2e69358d50
|
@ -78,12 +78,13 @@ data:
|
||||||
resize: false
|
resize: false
|
||||||
flip_p: 0.5
|
flip_p: 0.5
|
||||||
image_key: "image"
|
image_key: "image"
|
||||||
copyright_rate: 0.9
|
copyright_rate: 1.0
|
||||||
character_rate: 0.9
|
character_rate: 1.0
|
||||||
general_rate: 0.9
|
general_rate: 1.0
|
||||||
artist_rate: 0.9
|
artist_rate: 1.0
|
||||||
normalize: true
|
normalize: true
|
||||||
caption_shuffle: true
|
caption_shuffle: true
|
||||||
|
random_order: true
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
modelcheckpoint:
|
modelcheckpoint:
|
||||||
|
|
|
@ -36,7 +36,7 @@ def resize_image(image: Image, max_size=(768,768)):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
class CaptionProcessor(object):
|
class CaptionProcessor(object):
|
||||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
|
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
|
||||||
self.copyright_rate = copyright_rate
|
self.copyright_rate = copyright_rate
|
||||||
self.character_rate = character_rate
|
self.character_rate = character_rate
|
||||||
self.general_rate = general_rate
|
self.general_rate = general_rate
|
||||||
|
@ -46,6 +46,7 @@ class CaptionProcessor(object):
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.resize = resize
|
self.resize = resize
|
||||||
|
self.random_order = random_order
|
||||||
|
|
||||||
def clean(self, text: str):
|
def clean(self, text: str):
|
||||||
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
|
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
|
||||||
|
@ -71,11 +72,19 @@ class CaptionProcessor(object):
|
||||||
def __call__(self, sample):
|
def __call__(self, sample):
|
||||||
# preprocess caption
|
# preprocess caption
|
||||||
caption_data = json.loads(sample['caption'])
|
caption_data = json.loads(sample['caption'])
|
||||||
|
if not self.random_order:
|
||||||
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
|
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
|
||||||
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
|
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
|
||||||
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
|
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
|
||||||
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
|
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
|
||||||
sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
|
tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
|
||||||
|
else:
|
||||||
|
character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
|
||||||
|
copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
|
||||||
|
artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
|
||||||
|
general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
|
||||||
|
tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ')
|
||||||
|
sample['caption'] = tag_str
|
||||||
|
|
||||||
# preprocess image
|
# preprocess image
|
||||||
image = sample['image']
|
image = sample['image']
|
||||||
|
@ -150,7 +159,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
||||||
transform_dict = {}
|
transform_dict = {}
|
||||||
transform_dict.update({self.image_key: image_transforms})
|
transform_dict.update({self.image_key: image_transforms})
|
||||||
|
|
||||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)
|
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize, random_order=self.random_order)
|
||||||
|
|
||||||
|
|
||||||
tars = os.path.join(self.tar_base)
|
tars = os.path.join(self.tar_base)
|
||||||
|
|
Loading…
Reference in New Issue