Merge branch 'main' into refactor-data-resolution

commit b6c7299baf
@@ -12,12 +12,17 @@
     ]
    },
    {
+    "attachments": {},
     "cell_type": "markdown",
     "metadata": {
      "id": "blaLMSbkPHhG"
     },
     "source": [
-     "EveryDream2 Colab Edition"
+     "# EveryDream2 Colab Edition\n",
+     "\n",
+     "Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
+     "\n",
+     "And join the discord: https://discord.gg/uheqxU6sXN"
     ]
    },
    {
@@ -225,7 +230,7 @@
     "#@title Resume from a diffusers model saved to your Gdrive\n",
     "#@markdown * if you have preveiously saved diffusers on your drive you can slect them here\n",
     "#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n",
-    "Resume_Model = \"\" #@param{type:\"string\"} \n",
+    "Resume_Model = \"/content/drive/MyDrive/everydreamlogs/ckpt/SD15\" #@param{type:\"string\"} \n",
     "save_name= Resume_Model"
    ]
   },
@@ -268,7 +273,8 @@
     "#@markdown * Name your project so you can find it in your logs\n",
     "Project_Name = \"my_project\" #@param{type: 'string'}\n",
     "\n",
-    "#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data. See Advanced Tweaking for more info. Once you have started, the learning rate is a good first knob to turn as you move into more advanced tweaking.\n",
+    "#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model. \n",
+    "#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
     "\n",
     "Learning_Rate = 1e-6 #@param{type: 'number'}\n",
     "\n",
@@ -294,6 +300,8 @@
     "#@markdown * Remember more gradient accumulation (or batch size) doesn't automatically mean better\n",
     "\n",
     "Gradient_steps = 1 #@param{type:\"slider\", min:1, max:10, step:1}\n",
+    "\n",
+    "#@markdown * Location on your Gdrive where your training images are.\n",
     "Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
     "dataset = Dataset_Location\n",
     "model = save_name\n",
@@ -309,8 +317,11 @@
     "\n",
     "#@markdown You can set your own sample prompts by adding them, one line at a time, to `/content/EveryDream2trainer/sample_prompts.txt`. If left empty, it will use the captions from your training images.\n",
     "\n",
+    "#@markdown Use the steps_between_samples to set how often the samples are generated.\n",
     "Steps_between_samples = 300 #@param{type:\"integer\"}\n",
     "\n",
+    "#@markdown * That's it! Run the cell!\n",
+    "\n",
     "Drive=\"\"\n",
     "if Save_to_Gdrive:\n",
     " Drive = \"--logdir /content/drive/MyDrive/everydreamlogs --save_ckpt_dir /content/drive/MyDrive/everydreamlogs/ckpt\"\n",
@@ -347,7 +358,7 @@
     " $DX \\\n",
     " --amp \\\n",
     " --batch_size $Batch_Size \\\n",
-    " --grad_accum 2 \\\n",
+    " --grad_accum $Gradient_steps \\\n",
     " --cond_dropout 0.00 \\\n",
     " --data_root \"$dataset\" \\\n",
     " --flip_p 0.00 \\\n",
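A note on the `--grad_accum $Gradient_steps` wiring above: gradient accumulation trades extra forward/backward passes for a larger effective batch without the VRAM cost of raising `--batch_size`. A minimal sketch of the arithmetic, with illustrative values rather than anything from this diff:

    # Gradient accumulation: weights update once per `grad_accum` passes,
    # so each optimizer step sees batch_size * grad_accum images.
    batch_size = 4        # images per forward/backward pass (--batch_size)
    grad_accum = 2        # passes per optimizer step (--grad_accum)
    effective_batch = batch_size * grad_accum
    print(effective_batch)  # 8 images contribute to each weight update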
@@ -376,8 +387,21 @@
     "outputs": [],
     "source": [
     "#@title Alternate startup script\n",
-    "#@markdown Edit train.json to setup your paramaters\n",
-    "!python train.py --config train.json"
+    "#@markdown * Edit train.json to setup your paramaters\n",
+    "#@markdown * Edit chain0.json to make use of chaining\n",
+    "#@markdown * make sure to check each confguration you will need 1 Json per chain length 3 are provided\n",
+    "\n",
+    "\n",
+    "%cd /content/EveryDream2trainer\n",
+    "Chain_Length=0 #@param{type:\"integer\"}\n",
+    "l = Chain_Length \n",
+    "I=0 #repeat counter\n",
+    "if l == None or l == 0:\n",
+    " l=1\n",
+    "while l > 0:\n",
+    " !python train_colab.py --config chain{I}.json\n",
+    " l -= 1\n",
+    " I =+ 1"
    ]
   },
  ],
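Two details worth flagging in the new chaining cell: `I =+ 1` parses as `I = +1`, so the counter sticks at 1 after the first pass and every later link re-runs chain1.json (`I += 1` is presumably intended), and `l == None` is more idiomatically `l is None`. A hedged sketch of what the cell appears to intend, written as plain Python with `subprocess` standing in for the notebook's `!` shell escape:

    # Run chain0.json .. chain{n-1}.json once each; train_colab.py and the
    # chain{i}.json names come straight from the diff, the rest is assumed.
    import subprocess

    chain_length = 3  # the cell notes three chain configs are provided

    for i in range(max(chain_length, 1)):  # treat 0 (or unset) as a single run
        subprocess.run(
            ["python", "train_colab.py", "--config", f"chain{i}.json"],
            check=True,
        )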
@@ -17,6 +17,7 @@ import bisect
 import math
 import os
 import logging
+import copy

 import random
 from data.image_train_item import ImageTrainItem
@@ -46,7 +47,6 @@ class DataLoaderMultiAspect():
         self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)

         logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
-        logging.info(" Preloading images...")
         self.__prepare_train_data()
         (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()

@@ -55,32 +55,37 @@ class DataLoaderMultiAspect():
         """
         Deals with multiply.txt whole and fractional numbers
         """
-        prepared_train_data_local = self.prepared_train_data.copy()
+        #print(f"Picking multiplied set from {len(self.prepared_train_data)}")
+        data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
         epoch_size = len(self.prepared_train_data)
         picked_images = []

         # add by whole number part first and decrement multiplier in copy
-        for iti in prepared_train_data_local:
+        for iti in data_copy:
+            #print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
             while iti.multiplier >= 1.0:
                 picked_images.append(iti)
-                iti.multiplier -= 1
-                if iti.multiplier == 0:
-                    prepared_train_data_local.remove(iti)
+                #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
+                iti.multiplier -= 1.0

         remaining = epoch_size - len(picked_images)

         assert remaining >= 0, "Something went wrong with the multiplier calculation"
+        #print(f"Remaining to fill epoch after whole number adds: {remaining}")
+        #print(f"Remaining in data copy: {len(data_copy)}")

         # add by renaming fractional numbers by random chance
         while remaining > 0:
-            for iti in prepared_train_data_local:
-                if randomizer.uniform(0.0, 1) < iti.multiplier:
+            for iti in data_copy:
+                if randomizer.uniform(0.0, 1.0) < iti.multiplier:
+                    #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
                     picked_images.append(iti)
                     remaining -= 1
-                    prepared_train_data_local.remove(iti)
+                    data_copy.remove(iti)
                 if remaining <= 0:
                     break

+        del data_copy
         return picked_images

     def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
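The reworked `__pick_multiplied_set` is the heart of multiply.txt handling: the whole part of each image's multiplier contributes that many guaranteed copies, and the fractional remainder becomes a per-epoch probability used to top the epoch back up to size. The switch from `.copy()` to `copy.deepcopy` matters because the loop now decrements `iti.multiplier` in place, which would otherwise corrupt the multipliers needed for later epochs. A self-contained sketch of the technique; `Item` and the function name are illustrative, not from the repo:

    import copy
    import random
    from dataclasses import dataclass

    @dataclass
    class Item:
        pathname: str
        multiplier: float

    def pick_multiplied_set(items: list[Item], rng: random.Random) -> list[Item]:
        pool = copy.deepcopy(items)        # multipliers are decremented in place
        epoch_size = len(items)
        picked = []
        for it in pool:                    # whole part: guaranteed copies
            while it.multiplier >= 1.0:
                picked.append(it)
                it.multiplier -= 1.0
        remaining = epoch_size - len(picked)
        # fractional part: each leftover multiplier is a pick probability
        while remaining > 0 and any(it.multiplier > 0.0 for it in pool):
            for it in list(pool):          # snapshot; we remove from pool below
                if rng.uniform(0.0, 1.0) < it.multiplier:
                    picked.append(it)
                    remaining -= 1
                    pool.remove(it)
                if remaining <= 0:
                    break
        return picked

    epoch = pick_multiplied_set(
        [Item("a.jpg", 1.5), Item("b.jpg", 0.5)], random.Random(555))
    # a.jpg is guaranteed once; the second slot goes to a.jpg or b.jpg by chance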
@@ -150,9 +155,17 @@ class DataLoaderMultiAspect():
         """
         Create ImageTrainItem objects with metadata for hydration later
         """

         if not self.has_scanned:
             self.has_scanned = True
-            self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
+
+            logging.info(" Preloading images...")
+
+            items, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
+            image_paths = set(map(lambda item: item.pathname, items))
+
+            print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
+
+            self.prepared_train_data = items
             random.Random(self.seed).shuffle(self.prepared_train_data)
             self.__report_undersized_images(events)
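Since resolvers can now emit the same file several times (directory multipliers, below), the new log line distinguishes list entries from unique source files: `len(items)` counts entries while the `set` of pathnames counts files. In miniature:

    # Illustrative stand-ins for ImageTrainItem pathnames.
    items = ["a.jpg", "a.jpg", "b.jpg"]
    image_paths = set(items)
    print(f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
    # -> " * DLMA: 3 images loaded from 2 files"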
@@ -71,7 +71,6 @@ class EveryDreamBatch(Dataset):
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
-

         if seed == -1:
             seed = random.randint(0, 99999)

@@ -105,6 +105,10 @@ class ImageCaption:
                 caption += ", "
             caption += tag

+        if caption:
+            caption += ", "
+        caption += tag
+
         return caption

     @staticmethod
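The block appended here repeats the separator-then-append pattern of the loop above it. When the pieces are already in a list, the same result is conventionally spelled with `str.join`; a hedged equivalent, with `main_prompt` and `tags` as assumed names:

    # Comma-join a main prompt and its tags in one pass, skipping empties.
    main_prompt = "a photo of a corgi"
    tags = ["outdoors", "sunny"]
    caption = ", ".join(p for p in [main_prompt, *tags] if p)
    print(caption)  # a photo of a corgi, outdoors, sunny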
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import random
 import typing
 import zipfile

@@ -25,7 +26,7 @@ class UndersizedImageEvent(Event):
         self.target_size = target_size

 class DataResolver:
-    def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0):
+    def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
         self.aspects = aspects
         self.flip_p = flip_p
         self.events = []
@@ -141,6 +142,7 @@ class DirectoryResolver(DataResolver):
         items = []
         multipliers = {}
         skip_folders = []
+        randomizer = random.Random(self.seed)

         for pathname in tqdm.tqdm(image_paths):
             current_dir = os.path.dirname(pathname)
@@ -164,6 +166,16 @@ class DirectoryResolver(DataResolver):
             caption = ImageCaption.resolve(pathname)
             item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])

+            cur_file_multiplier = multipliers[current_dir]
+
+            while cur_file_multiplier >= 1.0:
+                items.append(item)
+                cur_file_multiplier -= 1
+
+            if cur_file_multiplier > 0:
+                if randomizer.random() < cur_file_multiplier:
+                    items.append(item)
+
             if item:
                 items.append(item)
         return items
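This hunk gives `DirectoryResolver` the same whole-plus-fractional multiplier treatment as the data loader: guaranteed copies for the integer part, a seeded coin flip for the remainder. Note that the pre-existing `if item: items.append(item)` survives the hunk, so each item lands in `items` once more on top of the multiplier copies; whether that extra append is intended is not decidable from the diff alone. The expansion step in isolation, with illustrative names:

    import random

    def expand_by_multiplier(item, multiplier: float, rng: random.Random) -> list:
        copies = []
        while multiplier >= 1.0:                     # whole part: deterministic
            copies.append(item)
            multiplier -= 1.0
        if multiplier > 0 and rng.random() < multiplier:
            copies.append(item)                      # fractional part: by chance
        return copies

    rng = random.Random(555)                         # seeded as in the resolver
    print(len(expand_by_multiplier("img.jpg", 2.4, rng)))  # 2 or 3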
@@ -210,17 +222,17 @@ def strategy(data_root: str):
         raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")


-def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
+def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
     """
     :param data_root: Directory or JSON file.
     :param aspects: The list of aspect ratios to use
     :param flip_p: The probability of flipping the image
     """
     if os.path.isfile(path) and path.endswith('.json'):
-        resolver = JSONResolver(aspects, flip_p)
+        resolver = JSONResolver(aspects, flip_p, seed)

     if os.path.isdir(path):
-        resolver = DirectoryResolver(aspects, flip_p)
+        resolver = DirectoryResolver(aspects, flip_p, seed)

     if not resolver:
         raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
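One catch in the new `resolve_root` signature: Python rejects a parameter without a default (`seed`) after one with a default (`flip_p`), so the line as committed raises `SyntaxError: non-default argument follows default argument` at import time. Given that `DataResolver.__init__` and `resolve` both default to `seed=555`, the intended signature is presumably:

    def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
        ...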
@@ -229,7 +241,7 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing
     events = resolver.events
     return items, events

-def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
+def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
     """
     Resolve the training data from the value.
     :param value: The value to resolve, either a dict or a string.
@@ -244,12 +256,12 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=
     match resolver:
         case 'directory' | 'json':
             path = value.get('path', None)
-            return resolve_root(path, aspects, flip_p)
+            return resolve_root(path, aspects, flip_p, seed)
         case 'multi':
             resolved_items = []
             resolved_events = []
             for resolver in value.get('resolvers', []):
-                items, events = resolve(resolver, aspects, flip_p)
+                items, events = resolve(resolver, aspects, flip_p, seed)
                 resolved_items.extend(items)
                 resolved_events.extend(events)
             return resolved_items, resolved_events
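For reference, a hypothetical config exercising the 'multi' branch above. The 'resolvers' and 'path' keys are visible in the diff; the top-level 'resolver' key name is assumed from the `match resolver:` dispatch:

    value = {
        "resolver": "multi",
        "resolvers": [
            {"resolver": "directory", "path": "/data/set_a"},
            {"resolver": "json", "path": "/data/set_b.json"},
        ],
    }
    items, events = resolve(value, aspects, flip_p=0.0, seed=555)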
 train.py | 28

@@ -1,5 +1,5 @@
 """
-Copyright [2022] Victor C Hall
+Copyright [2022-2023] Victor C Hall

 Licensed under the GNU Affero General Public License;
 You may not use this code except in compliance with the License.
@@ -343,7 +343,7 @@ def main(args):
         logging.info(f" * Saving SD model to {sd_ckpt_full}")
         converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)

-        if yaml_name:
+        if yaml_name and yaml_name != "v1-inference.yaml":
             yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
             logging.info(f" * Saving yaml to {yaml_save_path}")
             shutil.copyfile(yaml_name, yaml_save_path)
@@ -367,7 +367,6 @@ def main(args):
             safety_checker=None, # save vram
             requires_safety_checker=None, # avoid nag
             feature_extractor=None, # must be none of no safety checker
-            disable_tqdm=True,
         )

         return pipe
@@ -410,6 +409,8 @@ def main(args):
         generates samples at different cfg scales and saves them to disk
         """
         logging.info(f"Generating samples gs:{gs}, for {prompts}")
+        pipe.set_progress_bar_config(disable=True)
+
         seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
         gen = torch.Generator(device=device).manual_seed(seed)

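This pairs with the removal of `disable_tqdm=True` from the pipeline constructor earlier in the diff: per-sample progress bars are now silenced through diffusers' supported `set_progress_bar_config`. A hedged sketch of the resulting seeded, quiet sample generation, assuming `pipe` and `device` as in train.py and an illustrative prompt:

    import torch

    pipe.set_progress_bar_config(disable=True)  # no tqdm bar per generated image

    gen = torch.Generator(device=device).manual_seed(555)  # repeatable samples
    with torch.no_grad():
        image = pipe("a photo of a corgi", guidance_scale=7.5, generator=gen).images[0]
    image.save("sample.png")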
@@ -588,7 +589,7 @@ def main(args):

     """
     print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
-    print(f" (C) 2022 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
+    print(f" (C) 2022-2023 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
     print()
     print("** Trainer Starting **")

@@ -691,6 +692,24 @@ def main(args):
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

     try:
+        # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
+        # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
+        # pixel_values = pixel_values.to(unet.device)
+        # with autocast(enabled=args.amp):
+        # latents = vae.encode(pixel_values, return_dict=False)
+        # latents = latents[0].sample() * 0.18215
+        # noise = torch.randn_like(latents)
+        # bsz = latents.shape[0]
+        # timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        # timesteps = timesteps.long()
+        # noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        # cuda_caption = torch.linspace(100,177, steps=77, dtype=int).to(text_encoder.device)
+        # encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True).last_hidden_state
+        # with autocast(enabled=args.amp):
+        # model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        # # discard the grads, just want to pin memory
+        # optimizer.zero_grad(set_to_none=True)
+
         for epoch in range(args.max_epochs):
             loss_epoch = []
             epoch_start_time = time.time()
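The commented-out block preserves an experiment rather than live code: run one dummy pass at the square aspect (the maximum byte size per aspects.py) so torch's caching allocator reserves its largest blocks before variable-aspect batches arrive, then throw the result away. A minimal sketch of the idea, assuming `vae`, `unet`, and `args` as in train.py:

    import torch

    with torch.no_grad():  # gradients are irrelevant; we only want the allocation
        dummy = torch.zeros(args.batch_size, 3, args.resolution, args.resolution,
                            device=unet.device)
        _ = vae.encode(dummy).latent_dist.sample() * 0.18215  # result discarded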
@@ -801,7 +820,6 @@ def main(args):
                 if (global_step + 1) % args.sample_steps == 0:
                     pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                     pipe = pipe.to(device)
-                    #pipe.set_progress_bar_config(progress_bar=False)

                     with torch.no_grad():
                         if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: