From 759623142a63b9ef3e69a627c6ce61c3720ba05a Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sat, 18 Feb 2023 19:28:08 +0100 Subject: [PATCH] more cleanly fall back to random captions --- utils/sample_generator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index af12064..0154f3c 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -142,7 +142,7 @@ class SampleGenerator: self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) sample_requests_json = config.get('samples', None) if sample_requests_json is None: - self.sample_requests = self._make_random_caption_sample_requests() + self.sample_requests = [] else: default_seed = config.get('seed', self.default_seed) default_size = (self.default_resolution, self.default_resolution) @@ -152,16 +152,14 @@ class SampleGenerator: size=tuple(p.get('size', default_size)), wants_random_caption=p.get('random_caption', False) ) for p in sample_requests_json] + if len(self.sample_requests) == 0: + self._make_random_caption_sample_requests() @torch.no_grad() def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int): """ generates samples at different cfg scales and saves them to disk """ - if len(self.sample_requests) == 0: - raise NotImplementedError("todo: implement random captions") - #max_prompts = min(4, len(batch["captions"])) - #sample_requests = batch["captions"][:max_prompts] logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") pipe.set_progress_bar_config(disable=(not self.show_progress_bars))