add unconditional guidance training

This commit is contained in:
harubaru 2022-10-12 18:28:44 -07:00
parent 4f9070af3c
commit 169b63e179
2 changed files with 6 additions and 0 deletions

View File

@ -81,12 +81,14 @@ data:
params: params:
size: 512 size: 512
mode: "train" mode: "train"
ucg: 0.1 # unconditional guidance training
validation: validation:
target: ldm.data.local.LocalBase target: ldm.data.local.LocalBase
params: params:
size: 512 size: 512
mode: "val" mode: "val"
val_split: 64 val_split: 64
ucg: 0.1
lightning: lightning:
modelcheckpoint: modelcheckpoint:

View File

@ -109,11 +109,13 @@ class LocalDanbooruBase(Dataset):
shuffle=False, shuffle=False,
mode='train', mode='train',
val_split=64, val_split=64,
ucg=0.1,
): ):
super().__init__() super().__init__()
self.shuffle=shuffle self.shuffle=shuffle
self.crop = crop self.crop = crop
self.ucg = ucg
print('Fetching data.') print('Fetching data.')
@ -178,6 +180,8 @@ class LocalDanbooruBase(Dataset):
with open(text_file, 'rb') as f: with open(text_file, 'rb') as f:
image['caption'] = f.read() image['caption'] = f.read()
image = self.captionprocessor(image) image = self.captionprocessor(image)
if random.random() < self.ucg:
image['caption'] = ''
except Exception as e: except Exception as e:
print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}') print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
return self.skip_sample(i) return self.skip_sample(i)