add aspect_ratio arg to sample generation

This commit is contained in:
Damian Stewart 2023-03-02 22:16:21 +01:00
parent 52c722714c
commit fe0083877f
1 changed files with 11 additions and 15 deletions

View File

@ -53,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
yield b[i:i + batch_size] yield b[i:i + batch_size]
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
sizes = []
target_pixel_count = default_resolution * default_resolution
for w in range(256, 1024, 64):
for h in range(256, 1024, 64):
if abs((w * h) - target_pixel_count) <= 128 * 64:
sizes.append((w, h))
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
return best_size
class SampleGenerator: class SampleGenerator:
@ -155,23 +164,11 @@ class SampleGenerator:
self.sample_requests = self._make_random_caption_sample_requests() self.sample_requests = self._make_random_caption_sample_requests()
else: else:
default_seed = config.get('seed', self.default_seed) default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution)
#def make_size_from_aspect_ratio(aspect_ratio):
# if aspect_ratio is None:
# return None
# target_pixel_count = self.default_resolution * self.default_resolution
# w_ratio = aspect_ratio
# h_ratio = 1/w_ratio
# pixels_per_ratio_unit = target_pixel_count/(w_ratio + h_ratio)
# w = round(w_ratio*pixels_per_ratio_unit / 64) * 64
# h = round(h_ratio*pixels_per_ratio_unit / 64) * 64
# return [w,h]
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''), self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
negative_prompt=p.get('negative_prompt', ''), negative_prompt=p.get('negative_prompt', ''),
seed=p.get('seed', default_seed), seed=p.get('seed', default_seed),
size=p.get('size', default_size), size=tuple(p.get('size', None) or
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
wants_random_caption=p.get('random_caption', False) wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_config] ) for p in sample_requests_config]
if len(self.sample_requests) == 0: if len(self.sample_requests) == 0:
@ -200,7 +197,6 @@ class SampleGenerator:
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
for batch in batches: for batch in batches:
prompts = [p.prompt for p in batch] prompts = [p.prompt for p in batch]
pbar.set_postfix(postfix={'prompts': prompts})
negative_prompts = [p.negative_prompt for p in batch] negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
for p in batch] for p in batch]