more cleanly fall back to random captions
This commit is contained in:
parent
648fe20200
commit
759623142a
|
@ -142,7 +142,7 @@ class SampleGenerator:
|
|||
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
|
||||
sample_requests_json = config.get('samples', None)
|
||||
if sample_requests_json is None:
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
self.sample_requests = []
|
||||
else:
|
||||
default_seed = config.get('seed', self.default_seed)
|
||||
default_size = (self.default_resolution, self.default_resolution)
|
||||
|
@ -152,16 +152,14 @@ class SampleGenerator:
|
|||
size=tuple(p.get('size', default_size)),
|
||||
wants_random_caption=p.get('random_caption', False)
|
||||
) for p in sample_requests_json]
|
||||
if len(self.sample_requests) == 0:
|
||||
self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
if len(self.sample_requests) == 0:
|
||||
raise NotImplementedError("todo: implement random captions")
|
||||
#max_prompts = min(4, len(batch["captions"]))
|
||||
#sample_requests = batch["captions"][:max_prompts]
|
||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
||||
|
||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
||||
|
|
Loading…
Reference in New Issue