Custom Width and Height
commit 7a20f914ed
parent 6ad3a53e36
modules/textual_inversion/dataset.py
@@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
 
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
 
         self.placeholder_token = placeholder_token
 
-        self.size = size
-        self.width = size
-        self.height = size
+        self.width = width
+        self.height = height
 
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
 
         self.dataset = []
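A minimal usage sketch of the changed constructor follows; the argument values and the data path are illustrative assumptions, not taken from the commit. The module path matches the call site shown in a later hunk of this commit.

```python
# Illustrative sketch only: before this commit a single `size` argument set
# both dimensions; after it, width and height can differ independently.
from modules.textual_inversion.dataset import PersonalizedBase  # webui repo layout

ds = PersonalizedBase(
    data_root="train/images",   # example path, not from the commit
    width=512,
    height=768,                 # may now differ from width
    repeats=100,
    placeholder_token="my-token",
)
```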
modules/textual_inversion/preprocess.py
@@ -7,8 +7,9 @@ import tqdm
 from modules import shared, images
 
 
-def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
-    size = process_size
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+    width = process_width
+    height = process_height
     src = os.path.abspath(process_src)
     dst = os.path.abspath(process_dst)
 
@@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
         is_wide = ratio < 1 / 1.35
 
         if process_split and is_tall:
-            img = img.resize((size, size * img.height // img.width))
+            img = img.resize((width, height * img.height // img.width))
 
-            top = img.crop((0, 0, size, size))
+            top = img.crop((0, 0, width, height))
             save_pic(top, index)
 
-            bot = img.crop((0, img.height - size, size, img.height))
+            bot = img.crop((0, img.height - height, width, img.height))
             save_pic(bot, index)
         elif process_split and is_wide:
-            img = img.resize((size * img.width // img.height, size))
+            img = img.resize((width * img.width // img.height, height))
 
-            left = img.crop((0, 0, size, size))
+            left = img.crop((0, 0, width, height))
             save_pic(left, index)
 
-            right = img.crop((img.width - size, 0, img.width, size))
+            right = img.crop((img.width - width, 0, img.width, height))
             save_pic(right, index)
         else:
-            img = images.resize_image(1, img, size, size)
+            img = images.resize_image(1, img, width, height)
             save_pic(img, index)
 
         shared.state.nextjob()
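To make the split arithmetic above concrete, here is a small self-contained sketch using Pillow; the 1000x2000 source and the 512x512 target dimensions are made-up example values.

```python
from PIL import Image

# Worked example of the tall-image branch above (illustrative values only).
width, height = 512, 512
img = Image.new("RGB", (1000, 2000))          # stand-in for a tall source image

# Scale so the new width is exactly `width`; floor division keeps sizes integral.
img = img.resize((width, height * img.height // img.width))   # -> 512 x 1024

top = img.crop((0, 0, width, height))                          # upper 512 x 512
bot = img.crop((0, img.height - height, width, img.height))    # lower 512 x 512
print(top.size, bot.size)  # (512, 512) (512, 512)
```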
modules/textual_inversion/textual_inversion.py
@@ -6,7 +6,6 @@ import torch
 import tqdm
 import html
 import datetime
-import math
 
 
 from modules import shared, devices, sd_hijack, processing, sd_models
@@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
 
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
     hijack = sd_hijack.model_hijack
 
@@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
         loss.backward()
         optimizer.step()
 
-        epoch_num = math.floor(embedding.step / epoch_len)
+        epoch_num = embedding.step // epoch_len
         epoch_step = embedding.step - (epoch_num * epoch_len) + 1
 
         pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
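The `math.floor` change above is behavior-preserving for the non-negative integers involved, which is why `import math` could be dropped in the first hunk of this file; `//` also avoids the float round-trip of a true division. A tiny check, with illustrative values:

```python
import math

# For non-negative integers, floor division matches math.floor of true
# division, so the math import becomes unnecessary.
step, epoch_len = 2345, 100
assert step // epoch_len == math.floor(step / epoch_len) == 23
```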
@@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
                 sd_model=shared.sd_model,
                 prompt=text,
                 steps=20,
-                height=training_size,
-                width=training_size,
+                height=training_height,
+                width=training_width,
                 do_not_save_grid=True,
                 do_not_save_samples=True,
             )
modules/ui.py
@@ -1029,7 +1029,8 @@ def create_ui(wrap_gradio_gpu_call):
 
                 process_src = gr.Textbox(label='Source directory')
                 process_dst = gr.Textbox(label='Destination directory')
-                process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
+                process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+                process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
 
                 with gr.Row():
                     process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1050,7 +1051,8 @@ def create_ui(wrap_gradio_gpu_call):
                 dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                 log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
                 template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
-                training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
+                training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+                training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                 steps = gr.Number(label='Max steps', value=100000, precision=0)
                 num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
                 create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
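The slider changes above ripple into the two `inputs=[...]` hunks below: every click handler that listed the old single component must now list both new ones, in the same order as the target function's parameters. A minimal self-contained Gradio sketch of that wiring (component and function names here are illustrative, not the webui code):

```python
import gradio as gr

def preprocess_stub(width, height):
    # Stand-in for the preprocess function changed earlier in this commit.
    return f"would preprocess at {width}x{height}"

with gr.Blocks() as demo:
    # Two sliders replace the old single "Size (width and height)" slider.
    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
    result = gr.Textbox(label="Result")
    run = gr.Button("Run")
    # The `inputs` order must match the function's parameter order.
    run.click(fn=preprocess_stub, inputs=[width, height], outputs=[result])

# demo.launch()  # uncomment to try it locally
```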
@@ -1095,7 +1097,8 @@ def create_ui(wrap_gradio_gpu_call):
             inputs=[
                 process_src,
                 process_dst,
-                process_size,
+                process_width,
+                process_height,
                 process_flip,
                 process_split,
                 process_caption,
@@ -1114,7 +1117,8 @@ def create_ui(wrap_gradio_gpu_call):
                 learn_rate,
                 dataset_directory,
                 log_directory,
-                training_size,
+                training_width,
+                training_height,
                 steps,
                 num_repeats,
                 create_image_every,