isolate RNG in sample generation
This commit is contained in:
parent
759623142a
commit
230cab9e27
4
train.py
4
train.py
|
@ -57,6 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
|
||||||
from data.image_train_item import ImageTrainItem
|
from data.image_train_item import ImageTrainItem
|
||||||
from utils.huggingface_downloader import try_download_model_from_hf
|
from utils.huggingface_downloader import try_download_model_from_hf
|
||||||
from utils.convert_diff_to_ckpt import convert as converter
|
from utils.convert_diff_to_ckpt import convert as converter
|
||||||
|
from utils.isolate_rng import isolate_rng
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from utils.gpu import GPU
|
from utils.gpu import GPU
|
||||||
import data.aspects as aspects
|
import data.aspects as aspects
|
||||||
|
@ -798,7 +800,7 @@ def main(args):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if (global_step + 1) % args.sample_steps == 0:
|
if (global_step + 1) % args.sample_steps == 0:
|
||||||
|
with isolate_rng():
|
||||||
sample_generator.reload_config()
|
sample_generator.reload_config()
|
||||||
sample_generator.update_random_captions(batch["captions"])
|
sample_generator.update_random_captions(batch["captions"])
|
||||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os.path
|
import os.path
|
||||||
import traceback
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import random
|
import random
|
||||||
from typing import Generator, Callable, Any
|
from typing import Generator, Callable, Any
|
||||||
|
|
Loading…
Reference in New Issue