add aspect_ratio arg to sample generation
This commit is contained in:
parent
52c722714c
commit
fe0083877f
|
@ -53,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
|
|||
yield b[i:i + batch_size]
|
||||
|
||||
|
||||
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
|
||||
sizes = []
|
||||
target_pixel_count = default_resolution * default_resolution
|
||||
for w in range(256, 1024, 64):
|
||||
for h in range(256, 1024, 64):
|
||||
if abs((w * h) - target_pixel_count) <= 128 * 64:
|
||||
sizes.append((w, h))
|
||||
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
|
||||
return best_size
|
||||
|
||||
|
||||
class SampleGenerator:
|
||||
|
@ -155,23 +164,11 @@ class SampleGenerator:
|
|||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
else:
|
||||
default_seed = config.get('seed', self.default_seed)
|
||||
default_size = (self.default_resolution, self.default_resolution)
|
||||
#def make_size_from_aspect_ratio(aspect_ratio):
|
||||
# if aspect_ratio is None:
|
||||
# return None
|
||||
# target_pixel_count = self.default_resolution * self.default_resolution
|
||||
# w_ratio = aspect_ratio
|
||||
# h_ratio = 1/w_ratio
|
||||
# pixels_per_ratio_unit = target_pixel_count/(w_ratio + h_ratio)
|
||||
# w = round(w_ratio*pixels_per_ratio_unit / 64) * 64
|
||||
# h = round(h_ratio*pixels_per_ratio_unit / 64) * 64
|
||||
# return [w,h]
|
||||
|
||||
|
||||
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
|
||||
negative_prompt=p.get('negative_prompt', ''),
|
||||
seed=p.get('seed', default_seed),
|
||||
size=p.get('size', default_size),
|
||||
size=tuple(p.get('size', None) or
|
||||
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
|
||||
wants_random_caption=p.get('random_caption', False)
|
||||
) for p in sample_requests_config]
|
||||
if len(self.sample_requests) == 0:
|
||||
|
@ -200,7 +197,6 @@ class SampleGenerator:
|
|||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||
for batch in batches:
|
||||
prompts = [p.prompt for p in batch]
|
||||
pbar.set_postfix(postfix={'prompts': prompts})
|
||||
negative_prompts = [p.negative_prompt for p in batch]
|
||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||
for p in batch]
|
||||
|
|
Loading…
Reference in New Issue