more cleanly fall back to random captions

This commit is contained in:
Damian Stewart 2023-02-18 19:28:08 +01:00
parent 648fe20200
commit 759623142a
1 changed files with 3 additions and 5 deletions

View File

@ -142,7 +142,7 @@ class SampleGenerator:
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
sample_requests_json = config.get('samples', None) sample_requests_json = config.get('samples', None)
if sample_requests_json is None: if sample_requests_json is None:
self.sample_requests = self._make_random_caption_sample_requests() self.sample_requests = []
else: else:
default_seed = config.get('seed', self.default_seed) default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution) default_size = (self.default_resolution, self.default_resolution)
@ -152,16 +152,14 @@ class SampleGenerator:
size=tuple(p.get('size', default_size)), size=tuple(p.get('size', default_size)),
wants_random_caption=p.get('random_caption', False) wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_json] ) for p in sample_requests_json]
if len(self.sample_requests) == 0:
self._make_random_caption_sample_requests()
@torch.no_grad() @torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int): def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
""" """
generates samples at different cfg scales and saves them to disk generates samples at different cfg scales and saves them to disk
""" """
if len(self.sample_requests) == 0:
raise NotImplementedError("todo: implement random captions")
#max_prompts = min(4, len(batch["captions"]))
#sample_requests = batch["captions"][:max_prompts]
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
pipe.set_progress_bar_config(disable=(not self.show_progress_bars)) pipe.set_progress_bar_config(disable=(not self.show_progress_bars))