fixes minor typos around run_modelmerger
This commit is contained in:
parent
d1ea518dea
commit
f2ae252987
|
@ -275,7 +275,7 @@ def create_config(ckpt_result, config_source, a, b, c):
|
||||||
shutil.copyfile(cfg, checkpoint_filename)
|
shutil.copyfile(cfg, checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
|
|
||||||
def to_half(tensor, enable):
|
def to_half(tensor, enable):
|
||||||
|
@ -303,7 +303,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
def add_difference(theta0, theta1_2_diff, alpha):
|
def add_difference(theta0, theta1_2_diff, alpha):
|
||||||
return theta0 + (alpha * theta1_2_diff)
|
return theta0 + (alpha * theta1_2_diff)
|
||||||
|
|
||||||
def filename_weighed_sum():
|
def filename_weighted_sum():
|
||||||
a = primary_model_info.model_name
|
a = primary_model_info.model_name
|
||||||
b = secondary_model_info.model_name
|
b = secondary_model_info.model_name
|
||||||
Ma = round(1 - multiplier, 2)
|
Ma = round(1 - multiplier, 2)
|
||||||
|
@ -311,7 +311,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
|
|
||||||
return f"{Ma}({a}) + {Mb}({b})"
|
return f"{Ma}({a}) + {Mb}({b})"
|
||||||
|
|
||||||
def filename_add_differnece():
|
def filename_add_difference():
|
||||||
a = primary_model_info.model_name
|
a = primary_model_info.model_name
|
||||||
b = secondary_model_info.model_name
|
b = secondary_model_info.model_name
|
||||||
c = tertiary_model_info.model_name
|
c = tertiary_model_info.model_name
|
||||||
|
@ -323,8 +323,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
return primary_model_info.model_name
|
return primary_model_info.model_name
|
||||||
|
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted sum": (filename_weighed_sum, None, weighted_sum),
|
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
||||||
"Add difference": (filename_add_differnece, get_difference, add_difference),
|
"Add difference": (filename_add_difference, get_difference, add_difference),
|
||||||
"No interpolation": (filename_nothing, None, None),
|
"No interpolation": (filename_nothing, None, None),
|
||||||
}
|
}
|
||||||
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||||
|
@ -362,7 +362,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
shared.state.textinfo = 'Merging B and C'
|
shared.state.textinfo = 'Merging B and C'
|
||||||
shared.state.sampling_steps = len(theta_1.keys())
|
shared.state.sampling_steps = len(theta_1.keys())
|
||||||
for key in tqdm.tqdm(theta_1.keys()):
|
for key in tqdm.tqdm(theta_1.keys()):
|
||||||
if key in chckpoint_dict_skip_on_merge:
|
if key in checkpoint_dict_skip_on_merge:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 'model' in key:
|
if 'model' in key:
|
||||||
|
@ -387,7 +387,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
for key in tqdm.tqdm(theta_0.keys()):
|
for key in tqdm.tqdm(theta_0.keys()):
|
||||||
if theta_1 and 'model' in key and key in theta_1:
|
if theta_1 and 'model' in key and key in theta_1:
|
||||||
|
|
||||||
if key in chckpoint_dict_skip_on_merge:
|
if key in checkpoint_dict_skip_on_merge:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
a = theta_0[key]
|
a = theta_0[key]
|
||||||
|
|
Loading…
Reference in New Issue