diff --git a/utils/sample_generator.py b/utils/sample_generator.py index af12064..0154f3c 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -142,7 +142,7 @@ class SampleGenerator: self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) sample_requests_json = config.get('samples', None) if sample_requests_json is None: - self.sample_requests = self._make_random_caption_sample_requests() + self.sample_requests = [] else: default_seed = config.get('seed', self.default_seed) default_size = (self.default_resolution, self.default_resolution) @@ -152,16 +152,14 @@ class SampleGenerator: size=tuple(p.get('size', default_size)), wants_random_caption=p.get('random_caption', False) ) for p in sample_requests_json] + if len(self.sample_requests) == 0: + self._make_random_caption_sample_requests() @torch.no_grad() def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int): """ generates samples at different cfg scales and saves them to disk """ - if len(self.sample_requests) == 0: - raise NotImplementedError("todo: implement random captions") - #max_prompts = min(4, len(batch["captions"])) - #sample_requests = batch["captions"][:max_prompts] logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") pipe.set_progress_bar_config(disable=(not self.show_progress_bars))