commit
4cc3add770
|
@ -138,6 +138,7 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
|
||||||
|
# Always make sure inputs to unet are in correct dtype
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||||
|
|
||||||
|
@ -150,5 +151,6 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||||
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# Always make sure timestep calculation is in correct dtype
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
|
|
|
@ -56,14 +56,19 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
unet_dtype = torch.float
|
||||||
|
original_unet_dtype = devices.dtype_unet
|
||||||
|
|
||||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||||
unet.load_state_dict(unet_sd, strict=True)
|
unet.load_state_dict(unet_sd, strict=True)
|
||||||
unet.to(device=device, dtype=torch.float)
|
unet.to(device=device, dtype=unet_dtype)
|
||||||
|
devices.dtype_unet = unet_dtype
|
||||||
|
|
||||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||||
|
|
||||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
||||||
|
devices.dtype_unet = original_unet_dtype
|
||||||
|
|
||||||
return out < -1
|
return out < -1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue