Merge pull request #15803 from huchenlei/checkpoint_false
[Performance 1/6] use_checkpoint = False
commit ad229fae43
@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         transformer_depth: 1
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
 
     first_stage_config:
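For context, a minimal sketch of what the use_checkpoint flag gates in ldm-style UNet blocks (not part of this diff; maybe_checkpoint is a hypothetical stand-in for the helper the blocks actually call): when the flag is False, the block's forward runs directly and no activation-checkpointing bookkeeping is set up.

from torch.utils.checkpoint import checkpoint as torch_checkpoint

def maybe_checkpoint(func, inputs, flag):
    # Hypothetical helper, simplified for illustration only.
    if flag:
        # Training path: drop intermediate activations and recompute them
        # during the backward pass to save memory.
        return torch_checkpoint(func, *inputs, use_reentrant=False)
    # Inference path (use_checkpoint: False): plain call, no extra overhead.
    return func(*inputs)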
@@ -41,7 +41,7 @@ model:
         use_linear_in_transformer: True
         transformer_depth: 1
         context_dim: 1024
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
 
     first_stage_config:
@@ -45,7 +45,7 @@ model:
         use_spatial_transformer: True
         transformer_depth: 1
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
 
     first_stage_config:
@@ -21,7 +21,7 @@ model:
       params:
         adm_in_channels: 2816
         num_classes: sequential
-        use_checkpoint: True
+        use_checkpoint: False
         in_channels: 9
         out_channels: 4
         model_channels: 320
@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         transformer_depth: 1
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
 
     first_stage_config:
@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         transformer_depth: 1
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
 
     first_stage_config:
@@ -4,16 +4,19 @@ import ldm.modules.attention
 import ldm.modules.diffusionmodules.openaimodel
 
 
+# Setting flag=False so that torch skips checking parameters.
+# parameters checking is expensive in frequent operations.
+
 def BasicTransformerBlock_forward(self, x, context=None):
-    return checkpoint(self._forward, x, context)
+    return checkpoint(self._forward, x, context, flag=False)
 
 
 def AttentionBlock_forward(self, x):
-    return checkpoint(self._forward, x)
+    return checkpoint(self._forward, x, flag=False)
 
 
 def ResBlock_forward(self, x, emb):
-    return checkpoint(self._forward, x, emb)
+    return checkpoint(self._forward, x, emb, flag=False)
 
 
 stored = []
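A rough illustration (not from this commit; the module and shapes are arbitrary) of why activation checkpointing only pays off when a backward pass actually runs:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(8, 64, requires_grad=True)

# Training-style call: activations inside `block` are recomputed on backward,
# trading extra compute (and per-call bookkeeping) for lower peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()

# Inference-style call: there is no backward pass, so wrapping the forward in
# checkpoint() only adds overhead; calling the block directly is enough.
with torch.no_grad():
    y = block(x)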
@@ -552,6 +552,14 @@ def repair_config(sd_config):
         karlo_path = os.path.join(paths.models_path, 'karlo')
         sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
 
+    # Do not use checkpoint for inference.
+    # This helps prevent extra performance overhead on checking parameters.
+    # The perf overhead is about 100ms/it on 4090 for SDXL.
+    if hasattr(sd_config.model.params, "network_config"):
+        sd_config.model.params.network_config.params.use_checkpoint = False
+    if hasattr(sd_config.model.params, "unet_config"):
+        sd_config.model.params.unet_config.params.use_checkpoint = False
+
 
 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()
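The override above runs at config-load time, so it covers both SD1/SD2-style configs (unet_config) and SDXL-style configs (network_config) regardless of what the YAML file says. A simplified, self-contained sketch of the same guarded-override pattern, using SimpleNamespace as a hypothetical stand-in for the config object webui actually loads:

from types import SimpleNamespace

def disable_checkpoint_for_inference(sd_config):
    # Hypothetical helper mirroring the repair_config() change above.
    params = sd_config.model.params
    if hasattr(params, "network_config"):   # SDXL-style configs
        params.network_config.params.use_checkpoint = False
    if hasattr(params, "unet_config"):      # SD1/SD2-style configs
        params.unet_config.params.use_checkpoint = False

cfg = SimpleNamespace(model=SimpleNamespace(params=SimpleNamespace(
    unet_config=SimpleNamespace(params=SimpleNamespace(use_checkpoint=True))
)))
disable_checkpoint_for_inference(cfg)
assert cfg.model.params.unet_config.params.use_checkpoint is False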
@@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict):
 
     with sd_disable_initialization.DisableInitialization():
         unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
-            use_checkpoint=True,
+            use_checkpoint=False,
             use_fp16=False,
             image_size=32,
             in_channels=4,