Merge pull request #28 from AstraliteHeart/main

Misc cleanups from pony-diffusion test run
This commit is contained in:
Anthony Mercurio 2022-11-05 08:55:48 -07:00 committed by GitHub
commit 2321e22fc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 24 deletions

View File

@ -1,5 +1,9 @@
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
@ -535,7 +539,6 @@ def main():
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
tensor_format='pt'
)
# load dataset

View File

@ -11,7 +11,8 @@ streamlit>=0.73.1
einops==0.3.0
torch-fidelity==0.3.0
transformers==4.19.2
torchmetrics==0.6.0
diffusers==0.7.1
torchmetrics==0.7.0
kornia==0.6
gradio
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
@ -20,3 +21,4 @@ git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
webdataset
wandb
fairscale
pynvml==11.4.1