diff --git a/modules/devices.py b/modules/devices.py
index 07bb23397..08bb26d6f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32")
 
 
 device = get_optimal_device()
-device_codeformer = cpu if has_mps else device
+device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
 
 
 def randn(seed, shape):
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index bb30d7330..fcd8544a5 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -21,7 +21,7 @@ def gfpgann():
     global loaded_gfpgan_model
     global model_path
     if loaded_gfpgan_model is not None:
-        loaded_gfpgan_model.gfpgan.to(shared.device)
+        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
         return loaded_gfpgan_model
 
     if gfpgan_constructor is None:
@@ -36,8 +36,8 @@ def gfpgann():
     else:
         print("Unable to load gfpgan model!")
         return None
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
-    model.gfpgan.to(shared.device)
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
+    model.gfpgan.to(devices.device_gfpgan)
     loaded_gfpgan_model = model
 
     return model
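Note: a minimal sketch of the device-selection pattern the diff relies on, assuming only torch; get_optimal_device() and the device_gfpgan/device_codeformer names mirror modules/devices.py above, and the exact fallback behavior here is an illustration rather than the project's full logic.

import torch

cpu = torch.device("cpu")


def get_optimal_device():
    # Prefer CUDA, then MPS, then CPU, following modules/devices.py.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch, 'has_mps', False):
        return torch.device("mps")
    return cpu


device = get_optimal_device()
# Keep GFPGAN and CodeFormer on the CPU whenever the main device is MPS,
# which is what the changed line in the diff encodes.
device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device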