add gpu id support
This commit is contained in:
parent
422bd7a413
commit
051116a7d9
|
@ -21,6 +21,7 @@ import random
|
|||
from data.image_train_item import ImageTrainItem
|
||||
import data.aspects as aspects
|
||||
from colorama import Fore, Style
|
||||
import zipfile
|
||||
|
||||
class DataLoaderMultiAspect():
|
||||
"""
|
||||
|
@ -40,12 +41,24 @@ class DataLoaderMultiAspect():
|
|||
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
self.unzip_all(data_root)
|
||||
|
||||
self.__recurse_data_root(self=self, recurse_root=data_root)
|
||||
random.Random(seed).shuffle(self.image_paths)
|
||||
prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
|
||||
self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level)
|
||||
|
||||
#if debug_level > 0: print(f" * DLMA Example: {self.image_caption_pairs[0]} images")
|
||||
def unzip_all(self, path):
    """
    Recursively find every ``.zip`` file under ``path`` and extract it into ``path``.

    Extraction is best-effort: a failure on one archive is logged and the
    remaining archives are still processed. Nothing is raised to the caller.

    :param path: root folder that is walked for archives; also the
                 destination folder the archive contents are extracted to
    """
    for root, _dirs, files in os.walk(path):
        for file in files:
            if not file.endswith('.zip'):
                continue

            # os.walk yields bare file names; the archive itself lives in
            # `root`, so the full path has to be rebuilt before opening it.
            # (Opening `path` would try to read the directory as a zip.)
            archive_path = os.path.join(root, file)
            logging.info(f"Unzipping {file}")
            try:
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall(path)
            except Exception as e:
                # Deliberately broad: a corrupt or unreadable archive must not
                # stop preloading of the rest of the data set.
                logging.error(f"Error unzipping files {e}")
|
||||
|
||||
def get_all_images(self):
|
||||
return self.image_caption_pairs
|
||||
|
|
34
train.py
34
train.py
|
@ -56,6 +56,19 @@ _GRAD_ACCUM_STEPS = 1 # future use...
|
|||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
||||
# def is_notebook() -> bool:
|
||||
# try:
|
||||
# from IPython import get_ipython
|
||||
# shell = get_ipython().__class__.__name__
|
||||
# if shell == 'ZMQInteractiveShell':
|
||||
# return True # Jupyter notebook or qtconsole
|
||||
# elif shell == 'TerminalInteractiveShell':
|
||||
# return False # Terminal running IPython
|
||||
# else:
|
||||
# return False # Other type (?)
|
||||
# except NameError:
|
||||
# return False # Probably standard Python interpreter
|
||||
|
||||
def clean_filename(filename):
|
||||
"""
|
||||
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
||||
|
@ -163,10 +176,12 @@ def main(args):
|
|||
Main entry point
|
||||
"""
|
||||
log_time = setup_local_logger(args)
|
||||
#notebook = is_notebook()
|
||||
|
||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||
set_seed(seed)
|
||||
gpu = GPU()
|
||||
device = torch.device(f"cuda:{args.gpuid}")
|
||||
torch.backends.cudnn.benchmark = False
|
||||
args.clip_skip = max(min(3, args.clip_skip), 0)
|
||||
|
||||
|
@ -197,7 +212,7 @@ def main(args):
|
|||
args.lr = args.lr * (total_batch_size**0.5)
|
||||
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
logging.info(f"Logging to {log_folder}")
|
||||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
@ -368,9 +383,9 @@ def main(args):
|
|||
default_lr = 2e-6
|
||||
curr_lr = args.lr if args.lr is not None else default_lr
|
||||
|
||||
vae = vae.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
|
||||
unet = unet.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
|
||||
text_encoder = text_encoder.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
|
||||
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
||||
unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
|
@ -550,10 +565,10 @@ def main(args):
|
|||
#logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
|
||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0)
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
|
||||
steps_pbar = tqdm(range(epoch_len), position=1)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
epoch_times = []
|
||||
|
@ -566,7 +581,7 @@ def main(args):
|
|||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
loss = torch.tensor(0.0, device=torch.device("cuda"), dtype=torch.float32)
|
||||
loss = torch.tensor(0.0, device=device, dtype=torch.float32)
|
||||
|
||||
try:
|
||||
for epoch in range(args.max_epochs):
|
||||
|
@ -659,7 +674,7 @@ def main(args):
|
|||
if (global_step + 1) % args.sample_steps == 0:
|
||||
#(unet, text_encoder, tokenizer, scheduler):
|
||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae)
|
||||
pipe = pipe.to(torch.device("cuda"))
|
||||
pipe = pipe.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
||||
|
@ -732,7 +747,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
||||
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3])
|
||||
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
|
||||
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
|
@ -744,6 +759,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5")
|
||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
||||
args = argparser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
|
@ -3,7 +3,7 @@ call "venv\Scripts\activate.bat"
|
|||
echo should be in venv here
|
||||
cd .
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
|
||||
pip install transformers==4.25.1
|
||||
pip install diffusers[torch]==0.10.2
|
||||
pip install pynvml==11.4.1
|
||||
|
|
Loading…
Reference in New Issue