From 169b63e1797be60c524fb89f5db15280b714e376 Mon Sep 17 00:00:00 2001 From: harubaru Date: Wed, 12 Oct 2022 18:28:44 -0700 Subject: [PATCH] add unconditional guidance training --- configs/stable-diffusion/v1-4-finetune-test.yaml | 2 ++ ldm/data/localdanboorubase.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/configs/stable-diffusion/v1-4-finetune-test.yaml b/configs/stable-diffusion/v1-4-finetune-test.yaml index bd446c9..d2bb03e 100644 --- a/configs/stable-diffusion/v1-4-finetune-test.yaml +++ b/configs/stable-diffusion/v1-4-finetune-test.yaml @@ -81,12 +81,14 @@ data: params: size: 512 mode: "train" + ucg: 0.1 # unconditional guidance training validation: target: ldm.data.local.LocalBase params: size: 512 mode: "val" val_split: 64 + ucg: 0.1 lightning: modelcheckpoint: diff --git a/ldm/data/localdanboorubase.py b/ldm/data/localdanboorubase.py index d5fe473..bd3e765 100644 --- a/ldm/data/localdanboorubase.py +++ b/ldm/data/localdanboorubase.py @@ -109,11 +109,13 @@ class LocalDanbooruBase(Dataset): shuffle=False, mode='train', val_split=64, + ucg=0.1, ): super().__init__() self.shuffle=shuffle self.crop = crop + self.ucg = ucg print('Fetching data.') @@ -178,6 +180,8 @@ class LocalDanbooruBase(Dataset): with open(text_file, 'rb') as f: image['caption'] = f.read() image = self.captionprocessor(image) + if random.random() < self.ucg: + image['caption'] = '' except Exception as e: print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}') return self.skip_sample(i)