add unconditional guidance training
This commit is contained in:
parent
4f9070af3c
commit
169b63e179
|
@ -81,12 +81,14 @@ data:
|
||||||
params:
|
params:
|
||||||
size: 512
|
size: 512
|
||||||
mode: "train"
|
mode: "train"
|
||||||
|
ucg: 0.1 # unconditional guidance training
|
||||||
validation:
|
validation:
|
||||||
target: ldm.data.local.LocalBase
|
target: ldm.data.local.LocalBase
|
||||||
params:
|
params:
|
||||||
size: 512
|
size: 512
|
||||||
mode: "val"
|
mode: "val"
|
||||||
val_split: 64
|
val_split: 64
|
||||||
|
ucg: 0.1
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
modelcheckpoint:
|
modelcheckpoint:
|
||||||
|
|
|
@ -109,11 +109,13 @@ class LocalDanbooruBase(Dataset):
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
mode='train',
|
mode='train',
|
||||||
val_split=64,
|
val_split=64,
|
||||||
|
ucg=0.1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.shuffle=shuffle
|
self.shuffle=shuffle
|
||||||
self.crop = crop
|
self.crop = crop
|
||||||
|
self.ucg = ucg
|
||||||
|
|
||||||
print('Fetching data.')
|
print('Fetching data.')
|
||||||
|
|
||||||
|
@ -178,6 +180,8 @@ class LocalDanbooruBase(Dataset):
|
||||||
with open(text_file, 'rb') as f:
|
with open(text_file, 'rb') as f:
|
||||||
image['caption'] = f.read()
|
image['caption'] = f.read()
|
||||||
image = self.captionprocessor(image)
|
image = self.captionprocessor(image)
|
||||||
|
if random.random() < self.ucg:
|
||||||
|
image['caption'] = ''
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
|
print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
|
||||||
return self.skip_sample(i)
|
return self.skip_sample(i)
|
||||||
|
|
Loading…
Reference in New Issue