add aspect_ratio arg to sample generation
This commit is contained in:
parent
52c722714c
commit
fe0083877f
|
@ -53,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
|
||||||
yield b[i:i + batch_size]
|
yield b[i:i + batch_size]
|
||||||
|
|
||||||
|
|
||||||
|
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
|
||||||
|
sizes = []
|
||||||
|
target_pixel_count = default_resolution * default_resolution
|
||||||
|
for w in range(256, 1024, 64):
|
||||||
|
for h in range(256, 1024, 64):
|
||||||
|
if abs((w * h) - target_pixel_count) <= 128 * 64:
|
||||||
|
sizes.append((w, h))
|
||||||
|
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
|
||||||
|
return best_size
|
||||||
|
|
||||||
|
|
||||||
class SampleGenerator:
|
class SampleGenerator:
|
||||||
|
@ -155,23 +164,11 @@ class SampleGenerator:
|
||||||
self.sample_requests = self._make_random_caption_sample_requests()
|
self.sample_requests = self._make_random_caption_sample_requests()
|
||||||
else:
|
else:
|
||||||
default_seed = config.get('seed', self.default_seed)
|
default_seed = config.get('seed', self.default_seed)
|
||||||
default_size = (self.default_resolution, self.default_resolution)
|
|
||||||
#def make_size_from_aspect_ratio(aspect_ratio):
|
|
||||||
# if aspect_ratio is None:
|
|
||||||
# return None
|
|
||||||
# target_pixel_count = self.default_resolution * self.default_resolution
|
|
||||||
# w_ratio = aspect_ratio
|
|
||||||
# h_ratio = 1/w_ratio
|
|
||||||
# pixels_per_ratio_unit = target_pixel_count/(w_ratio + h_ratio)
|
|
||||||
# w = round(w_ratio*pixels_per_ratio_unit / 64) * 64
|
|
||||||
# h = round(h_ratio*pixels_per_ratio_unit / 64) * 64
|
|
||||||
# return [w,h]
|
|
||||||
|
|
||||||
|
|
||||||
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
|
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
|
||||||
negative_prompt=p.get('negative_prompt', ''),
|
negative_prompt=p.get('negative_prompt', ''),
|
||||||
seed=p.get('seed', default_seed),
|
seed=p.get('seed', default_seed),
|
||||||
size=p.get('size', default_size),
|
size=tuple(p.get('size', None) or
|
||||||
|
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
|
||||||
wants_random_caption=p.get('random_caption', False)
|
wants_random_caption=p.get('random_caption', False)
|
||||||
) for p in sample_requests_config]
|
) for p in sample_requests_config]
|
||||||
if len(self.sample_requests) == 0:
|
if len(self.sample_requests) == 0:
|
||||||
|
@ -200,7 +197,6 @@ class SampleGenerator:
|
||||||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
prompts = [p.prompt for p in batch]
|
prompts = [p.prompt for p in batch]
|
||||||
pbar.set_postfix(postfix={'prompts': prompts})
|
|
||||||
negative_prompts = [p.negative_prompt for p in batch]
|
negative_prompts = [p.negative_prompt for p in batch]
|
||||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||||
for p in batch]
|
for p in batch]
|
||||||
|
|
Loading…
Reference in New Issue