isolate RNG in sample generation
This commit is contained in:
parent
759623142a
commit
230cab9e27
24
train.py
24
train.py
|
@ -57,6 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
|
|||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
import data.aspects as aspects
|
||||
|
@ -798,18 +800,18 @@ def main(args):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
if (global_step + 1) % args.sample_steps == 0:
|
||||
with isolate_rng():
|
||||
sample_generator.reload_config()
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
sample_generator.reload_config()
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
del inference_pipe
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
from typing import Generator, Callable, Any
|
||||
|
|
Loading…
Reference in New Issue