isolate RNG in sample generation

This commit is contained in:
Damian Stewart 2023-02-18 20:18:21 +01:00
parent 759623142a
commit 230cab9e27
2 changed files with 13 additions and 12 deletions

View File

@ -57,6 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
if torch.cuda.is_available(): if torch.cuda.is_available():
from utils.gpu import GPU from utils.gpu import GPU
import data.aspects as aspects import data.aspects as aspects
@ -798,7 +800,7 @@ def main(args):
torch.cuda.empty_cache() torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0: if (global_step + 1) % args.sample_steps == 0:
with isolate_rng():
sample_generator.reload_config() sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"]) sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet, inference_pipe = sample_generator.create_inference_pipe(unet=unet,

View File

@ -1,7 +1,6 @@
import json import json
import logging import logging
import os.path import os.path
import traceback
from dataclasses import dataclass from dataclasses import dataclass
import random import random
from typing import Generator, Callable, Any from typing import Generator, Callable, Any