When device is MPS, use CPU for GFPGAN instead
GFPGAN will not work if the device is MPS, so default to CPU instead.
This commit is contained in:
parent
84e97a98c5
commit
bdaa36c844
|
@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
|
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
device_codeformer = cpu if has_mps else device
|
device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
|
|
|
@ -21,7 +21,7 @@ def gfpgann():
|
||||||
global loaded_gfpgan_model
|
global loaded_gfpgan_model
|
||||||
global model_path
|
global model_path
|
||||||
if loaded_gfpgan_model is not None:
|
if loaded_gfpgan_model is not None:
|
||||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||||
return loaded_gfpgan_model
|
return loaded_gfpgan_model
|
||||||
|
|
||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
|
@ -36,8 +36,8 @@ def gfpgann():
|
||||||
else:
|
else:
|
||||||
print("Unable to load gfpgan model!")
|
print("Unable to load gfpgan model!")
|
||||||
return None
|
return None
|
||||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||||
model.gfpgan.to(shared.device)
|
model.gfpgan.to(devices.device_gfpgan)
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue