2022-10-09 19:26:52 -06:00
import math
2022-09-19 16:13:12 -06:00
import os
2022-09-11 02:31:16 -06:00
import numpy as np
from PIL import Image
2022-09-25 17:22:12 -06:00
import torch
2022-09-27 01:44:00 -06:00
import tqdm
2022-09-25 17:22:12 -06:00
2022-09-28 15:21:54 -06:00
from modules import processing , shared , images , devices , sd_models
2022-09-11 02:31:16 -06:00
from modules . shared import opts
import modules . gfpgan_model
from modules . ui import plaintext_to_html
import modules . codeformer_model
2022-09-13 10:23:55 -06:00
import piexif
2022-09-14 06:20:05 -06:00
import piexif . helper
2022-09-28 15:59:44 -06:00
import gradio as gr
2022-09-13 10:23:55 -06:00
2022-09-11 02:31:16 -06:00
# Module-level cache of upscaler results, shared across run_extras calls.
# run_extras looks results up here before invoking an upscaler and trims the
# dict to at most two entries after each processed image.
cached_images = {}
2022-10-09 19:26:52 -06:00
def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
    """Apply face restoration and upscaling to one image or a batch of files.

    Args:
        extras_mode: 1 selects batch mode (every file in ``image_folder`` is
            processed); any other value processes the single ``image``.
        resize_mode: 1 means "scale to ``upscaling_resize_w`` x
            ``upscaling_resize_h``" (the factor is derived per image as the
            larger of the two ratios); otherwise ``upscaling_resize`` is used
            as the scale factor directly.
        image: PIL image used in single-image mode. ``None`` yields the
            "Please select an input image." message.
        image_folder: iterable of uploaded file objects (each exposing
            ``orig_name`` and accepted by ``Image.open``). ``None`` or empty
            is treated as an empty batch.
        gfpgan_visibility: blend weight of the GFPGAN result; 0 disables it,
            values >= 1.0 use the restored image unblended.
        codeformer_visibility: same as above for CodeFormer.
        codeformer_weight: passed to CodeFormer as ``w``.
        upscaling_resize: scale factor used when ``resize_mode`` is not 1.
        upscaling_resize_w: target width for ``resize_mode == 1``.
        upscaling_resize_h: target height for ``resize_mode == 1``.
        upscaling_crop: when true in ``resize_mode == 1``, the upscaled image
            is center-cropped to the target size.
        extras_upscaler_1: index into ``shared.sd_upscalers``.
        extras_upscaler_2: index of an optional second upscaler; 0 disables it.
        extras_upscaler_2_visibility: blend weight of the second upscaler.

    Returns:
        A 3-tuple ``(outputs, html_info, '')`` where ``outputs`` is the list
        of processed PIL images and ``html_info`` describes the steps applied
        to the last processed image (empty when nothing was processed).
    """
    devices.torch_gc()

    imageArr = []
    # Also keep track of original file names
    imageNameArr = []

    if extras_mode == 1:
        # Batch mode: convert each uploaded file to a pillow image.
        # Nothing uploaded may arrive as None, so treat that as an empty batch.
        for img in image_folder or []:
            imageArr.append(Image.open(img))
            imageNameArr.append(os.path.splitext(img.orig_name)[0])
    else:
        imageArr.append(image)
        imageNameArr.append(None)

    outpath = opts.outdir_samples or opts.outdir_extras_samples

    def crop_upscaled_center(image, resize_w, resize_h):
        # ceil on one side and floor on the other keeps the cropped size
        # exactly resize_w x resize_h even when the excess is odd.
        left = int(math.ceil((image.width - resize_w) / 2))
        right = image.width - int(math.floor((image.width - resize_w) / 2))
        top = int(math.ceil((image.height - resize_h) / 2))
        bottom = image.height - int(math.floor((image.height - resize_h) / 2))
        return image.crop((left, top, right, bottom))

    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
        # A 10x10 patch from the image centre serves as a cheap fingerprint of
        # the image content for the cache key.
        small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
        pixels = tuple(np.array(small).flatten().tolist())
        # The cached result differs depending on whether (and to what size) it
        # was cropped, so that request must be part of the key; otherwise
        # toggling the crop option could return a stale image.
        crop_to = (resize_w, resize_h) if mode == 1 and crop else None
        key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, crop_to) + pixels

        c = cached_images.get(key)
        if c is None:
            upscaler = shared.sd_upscalers[scaler_index]
            c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
            if crop_to is not None:
                c = crop_upscaled_center(c, resize_w, resize_h)
            cached_images[key] = c

        return c

    outputs = []
    # Initialised here so the final return works for an empty batch as well.
    info = ""
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''

        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
            res = Image.fromarray(restored_img)

            if gfpgan_visibility < 1.0:
                res = Image.blend(image, res, gfpgan_visibility)

            info += f"GFPGAN visibility: {round(gfpgan_visibility, 2)}\n"
            image = res

        if codeformer_visibility > 0:
            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
            res = Image.fromarray(restored_img)

            if codeformer_visibility < 1.0:
                res = Image.blend(image, res, codeformer_visibility)

            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility: {round(codeformer_visibility, 2)}\n"
            image = res

        if resize_mode == 1:
            # Scale so the image covers the requested size in both dimensions.
            upscaling_resize = max(upscaling_resize_w / image.width, upscaling_resize_h / image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"

        if upscaling_resize != 1.0:
            info += f"Upscale: {round(upscaling_resize, 3)}, model: {shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)

            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model: {shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)

            image = res

        # Keep only the most recent entries; dicts iterate in insertion order,
        # so the first key is the oldest one.
        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images))]

        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
                          forced_filename=image_name if opts.use_original_name_batch else None)

        if opts.enable_pnginfo:
            image.info = existing_pnginfo
            image.info["extras"] = info

        outputs.append(image)

    devices.torch_gc()

    return outputs, plaintext_to_html(info), ''
2022-09-11 02:31:16 -06:00
2022-09-17 00:07:07 -06:00
def run_pnginfo(image):
    """Render the metadata stored in an image as HTML.

    Args:
        image: an object exposing an ``info`` mapping (e.g. a PIL image), or
            ``None``.

    Returns:
        A 3-tuple ``('', geninfo, info_html)``. ``geninfo`` is the value of
        the ``parameters`` entry if present, otherwise the decoded EXIF user
        comment, otherwise ``''``. ``info_html`` contains one ``<div>`` per
        remaining metadata entry, or a "Nothing found" message when there are
        none. For ``image is None`` all three elements are ``''``.
    """
    if image is None:
        return '', '', ''

    # Work on a copy: entries are added and removed below and the caller's
    # image metadata must not change as a side effect of displaying it.
    items = dict(image.info or {})
    geninfo = ''

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            # Not in the format UserComment.load accepts; fall back to a
            # lenient decode of the raw bytes.
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        items['exif comment'] = exif_comment
        geninfo = exif_comment

    # These entries are not shown in the rendered output.
    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                  'loop', 'background', 'timestamp', 'duration']:
        items.pop(field, None)

    # An explicit 'parameters' entry takes precedence over the EXIF comment.
    geninfo = items.get('parameters', geninfo)

    info = ''.join(
        f"<div>\n<p><b>{plaintext_to_html(str(key))}</b></p>\n<p>{plaintext_to_html(str(text))}</p>\n</div>\n"
        for key, text in items.items()
    )

    if not info:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info
2022-09-25 17:22:12 -06:00
2022-09-28 17:50:34 -06:00
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
    """Merge two checkpoints by interpolating their weights and save the result.

    Args:
        primary_model_name: key into ``sd_models.checkpoints_list``.
        secondary_model_name: key into ``sd_models.checkpoints_list``.
        interp_method: one of "Weighted Sum", "Sigmoid", "Inverse Sigmoid".
        interp_amount: mix amount; ``1.0 - interp_amount`` is passed to the
            interpolation function as its alpha.
        save_as_half: when true, every merged or copied tensor is converted
            with ``.half()``.
        custom_name: output file name without extension; ``''`` selects an
            auto-generated name.

    Returns:
        A list holding a status message followed by three
        ``gr.Dropdown.update`` objects carrying the refreshed checkpoint
        choices.
    """

    def lerp(theta0, theta1, alpha):
        # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def smoothstep(theta0, theta1, alpha):
        # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    def inverse_smoothstep(theta0, theta1, alpha):
        # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

    primary_info = sd_models.checkpoints_list[primary_model_name]
    secondary_info = sd_models.checkpoints_list[secondary_model_name]

    # NOTE(review): torch.load unpickles the file; only trusted checkpoints
    # should be merged.
    print(f"Loading {primary_info.filename}...")
    primary_model = torch.load(primary_info.filename, map_location='cpu')

    print(f"Loading {secondary_info.filename}...")
    secondary_model = torch.load(secondary_info.filename, map_location='cpu')

    # NOTE(review): primary_model is what gets saved below, so theta_0 is
    # presumably a view into it (not a copy) - confirm against sd_models.
    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

    interpolators = {
        "Weighted Sum": lerp,
        "Sigmoid": smoothstep,
        "Inverse Sigmoid": inverse_smoothstep,
    }
    interpolate = interpolators[interp_method]

    # The interpolation amount is reversed to match the desired mix ratio in
    # the merged checkpoint.
    reversed_amount = float(1.0) - interp_amount

    print("Merging...")
    for name in tqdm.tqdm(theta_0.keys()):
        if 'model' not in name or name not in theta_1:
            continue
        blended = interpolate(theta_0[name], theta_1[name], reversed_amount)
        theta_0[name] = blended.half() if save_as_half else blended

    # Carry over weights that exist only in the secondary checkpoint.
    for name in theta_1.keys():
        if 'model' not in name or name in theta_0:
            continue
        extra = theta_1[name]
        theta_0[name] = extra.half() if save_as_half else extra

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    default_filename = "".join([
        primary_info.model_name,
        '_',
        str(round(interp_amount, 2)),
        '-',
        secondary_info.model_name,
        '_',
        str(round(reversed_amount, 2)),
        '-',
        interp_method.replace(" ", "_"),
        '-merged.ckpt',
    ])
    if custom_name == '':
        filename = default_filename
    else:
        filename = custom_name + '.ckpt'

    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")
    torch.save(primary_model, output_modelname)

    # Rescan so the new checkpoint shows up in the dropdown choices.
    sd_models.list_models()

    print("Checkpoint saved.")

    result = ["Checkpoint saved to " + output_modelname]
    for _ in range(3):
        result.append(gr.Dropdown.update(choices=sd_models.checkpoint_tiles()))
    return result