add unconditional guidance training

This commit is contained in:
harubaru 2022-10-12 18:28:44 -07:00
parent 4f9070af3c
commit 169b63e179
2 changed files with 6 additions and 0 deletions

View File

@ -81,12 +81,14 @@ data:
params:
size: 512
mode: "train"
ucg: 0.1 # unconditional guidance training
validation:
target: ldm.data.local.LocalBase
params:
size: 512
mode: "val"
val_split: 64
ucg: 0.1
lightning:
modelcheckpoint:

View File

@ -109,11 +109,13 @@ class LocalDanbooruBase(Dataset):
shuffle=False,
mode='train',
val_split=64,
ucg=0.1,
):
super().__init__()
self.shuffle=shuffle
self.crop = crop
self.ucg = ucg
print('Fetching data.')
@ -178,6 +180,8 @@ class LocalDanbooruBase(Dataset):
with open(text_file, 'rb') as f:
image['caption'] = f.read()
image = self.captionprocessor(image)
if random.random() < self.ucg:
image['caption'] = ''
except Exception as e:
print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
return self.skip_sample(i)