isolate RNG in sample generation

Damian Stewart 2023-02-18 20:18:21 +01:00
parent 759623142a
commit 230cab9e27
2 changed files with 13 additions and 12 deletions

Changed file 1 of 2:

@@ -57,6 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
 from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
+from utils.isolate_rng import isolate_rng
 if torch.cuda.is_available():
     from utils.gpu import GPU
 import data.aspects as aspects
@@ -798,18 +800,18 @@ def main(args):
                 torch.cuda.empty_cache()
             if (global_step + 1) % args.sample_steps == 0:
-                sample_generator.reload_config()
-                sample_generator.update_random_captions(batch["captions"])
-                inference_pipe = sample_generator.create_inference_pipe(unet=unet,
-                                                                        text_encoder=text_encoder,
-                                                                        tokenizer=tokenizer,
-                                                                        vae=vae,
-                                                                        diffusers_scheduler_config=reference_scheduler.config
-                                                                        ).to(device)
-                sample_generator.generate_samples(inference_pipe, global_step)
-                del inference_pipe
+                with isolate_rng():
+                    sample_generator.reload_config()
+                    sample_generator.update_random_captions(batch["captions"])
+                    inference_pipe = sample_generator.create_inference_pipe(unet=unet,
+                                                                            text_encoder=text_encoder,
+                                                                            tokenizer=tokenizer,
+                                                                            vae=vae,
+                                                                            diffusers_scheduler_config=reference_scheduler.config
+                                                                            ).to(device)
+                    sample_generator.generate_samples(inference_pipe, global_step)
+                    del inference_pipe
                 gc.collect()
                 torch.cuda.empty_cache()
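
For context, isolate_rng() is used above as a context manager that snapshots the global random-number-generator state when the block is entered and restores it when the block exits, so any seeding done while generating preview samples cannot shift the training run's subsequent randomness (data shuffling, noise sampling, and so on). The code below is a minimal sketch of that pattern, assuming behaviour along the lines of pytorch_lightning.utilities.seed.isolate_rng; it is an illustration, not the repository's actual utils/isolate_rng.py.

# Minimal sketch (assumption, not the repo's implementation): save the global
# RNG states on entry and restore them on exit, so seeding inside the block
# cannot change the randomness seen by the surrounding training loop.
import random
from contextlib import contextmanager

import torch


@contextmanager
def isolate_rng():
    # snapshot Python and torch RNG states (CPU and, if available, all CUDA devices)
    python_state = random.getstate()
    torch_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        # restore the snapshots regardless of what the sampling code seeded
        random.setstate(python_state)
        torch.set_rng_state(torch_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

With a manager like this in place, calling torch.manual_seed(...) or building seeded torch.Generator objects inside the block yields reproducible sample images without disturbing the RNG sequence the training loop sees afterwards.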

Changed file 2 of 2:

@@ -1,7 +1,6 @@
import json
import logging
import os.path
import traceback
from dataclasses import dataclass
import random
from typing import Generator, Callable, Any