Variable dropout rate
Implements variable dropout rate from #4549 Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training. Changes function name to match torch.nn.module standard Fixes RNG reset issue when generating previews by restoring RNG state
This commit is contained in:
parent
bd4587d2f5
commit
a4a5475cfa
|
@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
|
@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module):
|
|||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout except last layer
|
||||
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
# Everything should be now parsed into dropout structure, and applied here.
|
||||
# Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
|
||||
if dropout_structure is not None and dropout_structure[i+1] > 0:
|
||||
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
|
||||
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
|
||||
# Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
|
@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * self.multiplier
|
||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
|
@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module):
|
|||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
||||
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
||||
if layer_structure is None:
|
||||
layer_structure = [1, 2, 1]
|
||||
if not use_dropout:
|
||||
return [0] * len(layer_structure)
|
||||
dropout_values = [0]
|
||||
dropout_values.extend([0.3] * (len(layer_structure) - 3))
|
||||
if last_layer_dropout:
|
||||
dropout_values.append(0.3)
|
||||
else:
|
||||
dropout_values.append(0)
|
||||
dropout_values.append(0)
|
||||
return dropout_values
|
||||
|
||||
|
||||
class Hypernetwork:
|
||||
filename = None
|
||||
|
@ -144,18 +162,22 @@ class Hypernetwork:
|
|||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
|
||||
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
||||
self.dropout_structure = kwargs.get('dropout_structure', None)
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.optional_info = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
)
|
||||
self.eval_mode()
|
||||
self.eval()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
|
@ -164,14 +186,14 @@ class Hypernetwork:
|
|||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train_mode(self):
|
||||
def train(self, mode=True):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
layer.train(mode=mode)
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = True
|
||||
param.requires_grad = mode
|
||||
|
||||
def eval_mode(self):
|
||||
def eval(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
|
@ -191,11 +213,13 @@ class Hypernetwork:
|
|||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['last_layer_dropout'] = self.last_layer_dropout
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['dropout_structure'] = self.dropout_structure
|
||||
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
||||
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
@ -215,43 +239,56 @@ class Hypernetwork:
|
|||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
optional_info = state_dict.get('optional_info', None)
|
||||
if optional_info is not None:
|
||||
print(f"INFO:\n {optional_info}\n")
|
||||
self.optional_info = optional_info
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||
if self.dropout_structure is None:
|
||||
print("Using previous dropout structure")
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
print(f"Dropout structure is set to {self.dropout_structure}")
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
else:
|
||||
self.optimizer_name = "AdamW"
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
self.eval()
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
|
@ -379,9 +416,10 @@ def report_statistics(loss_info:dict):
|
|||
print(e)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
assert name, "Name cannot be empty!"
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
|
@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
if type(layer_structure) == str:
|
||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
|
||||
if use_dropout and dropout_structure and type(dropout_structure) == str:
|
||||
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||
else:
|
||||
dropout_structure = [0] * len(layer_structure)
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
|
@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
weight_init=weight_init,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
dropout_structure=dropout_structure
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
|
@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
hypernetwork.train_mode()
|
||||
hypernetwork.train()
|
||||
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
|
@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
hypernetwork.eval_mode()
|
||||
hypernetwork.eval()
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
if torch.cuda.is_available():
|
||||
cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
|
@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
hypernetwork.train_mode()
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_rng_state_all(cuda_rng_state)
|
||||
hypernetwork.train()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
|
@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval_mode()
|
||||
hypernetwork.eval()
|
||||
#report_statistics(loss_dict)
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
|
|
@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared
|
|||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||
|
||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
||||
|
||||
|
|
|
@ -1268,6 +1268,7 @@ def create_ui():
|
|||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
|
||||
new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
|
||||
|
||||
with gr.Row():
|
||||
|
@ -1414,7 +1415,8 @@ def create_ui():
|
|||
new_hypernetwork_activation_func,
|
||||
new_hypernetwork_initialization_option,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
new_hypernetwork_use_dropout
|
||||
new_hypernetwork_use_dropout,
|
||||
new_hypernetwork_dropout_structure
|
||||
],
|
||||
outputs=[
|
||||
train_hypernetwork_name,
|
||||
|
|
Loading…
Reference in New Issue