more cleanly fall back to random captions
This commit is contained in:
parent
648fe20200
commit
759623142a
|
@ -142,7 +142,7 @@ class SampleGenerator:
|
||||||
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
|
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
|
||||||
sample_requests_json = config.get('samples', None)
|
sample_requests_json = config.get('samples', None)
|
||||||
if sample_requests_json is None:
|
if sample_requests_json is None:
|
||||||
self.sample_requests = self._make_random_caption_sample_requests()
|
self.sample_requests = []
|
||||||
else:
|
else:
|
||||||
default_seed = config.get('seed', self.default_seed)
|
default_seed = config.get('seed', self.default_seed)
|
||||||
default_size = (self.default_resolution, self.default_resolution)
|
default_size = (self.default_resolution, self.default_resolution)
|
||||||
|
@ -152,16 +152,14 @@ class SampleGenerator:
|
||||||
size=tuple(p.get('size', default_size)),
|
size=tuple(p.get('size', default_size)),
|
||||||
wants_random_caption=p.get('random_caption', False)
|
wants_random_caption=p.get('random_caption', False)
|
||||||
) for p in sample_requests_json]
|
) for p in sample_requests_json]
|
||||||
|
if len(self.sample_requests) == 0:
|
||||||
|
self._make_random_caption_sample_requests()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||||
"""
|
"""
|
||||||
generates samples at different cfg scales and saves them to disk
|
generates samples at different cfg scales and saves them to disk
|
||||||
"""
|
"""
|
||||||
if len(self.sample_requests) == 0:
|
|
||||||
raise NotImplementedError("todo: implement random captions")
|
|
||||||
#max_prompts = min(4, len(batch["captions"]))
|
|
||||||
#sample_requests = batch["captions"][:max_prompts]
|
|
||||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
||||||
|
|
||||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
||||||
|
|
Loading…
Reference in New Issue