Merge pull request #28 from AstraliteHeart/main
Misc cleanups from pony-diffusion test run
This commit is contained in:
commit
2321e22fc1
|
@ -1,5 +1,9 @@
|
||||||
|
# Install bitsandbytes:
|
||||||
|
# `nvcc --version` to get CUDA version.
|
||||||
|
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
|
||||||
# Example Usage:
|
# Example Usage:
|
||||||
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import socket
|
||||||
|
@ -535,7 +539,6 @@ def main():
|
||||||
beta_end=0.012,
|
beta_end=0.012,
|
||||||
beta_schedule='scaled_linear',
|
beta_schedule='scaled_linear',
|
||||||
num_train_timesteps=1000,
|
num_train_timesteps=1000,
|
||||||
tensor_format='pt'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
|
|
|
@ -11,7 +11,8 @@ streamlit>=0.73.1
|
||||||
einops==0.3.0
|
einops==0.3.0
|
||||||
torch-fidelity==0.3.0
|
torch-fidelity==0.3.0
|
||||||
transformers==4.19.2
|
transformers==4.19.2
|
||||||
torchmetrics==0.6.0
|
diffusers==0.7.1
|
||||||
|
torchmetrics==0.7.0
|
||||||
kornia==0.6
|
kornia==0.6
|
||||||
gradio
|
gradio
|
||||||
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
|
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
|
||||||
|
@ -20,3 +21,4 @@ git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||||
webdataset
|
webdataset
|
||||||
wandb
|
wandb
|
||||||
fairscale
|
fairscale
|
||||||
|
pynvml==11.4.1
|
Loading…
Reference in New Issue