torch2 stuff

This commit is contained in:
Victor Hall 2023-03-27 01:06:03 -04:00
parent 56256ab9ef
commit 036f7a5818
3 changed files with 71 additions and 12 deletions

33
torch2_install.md Normal file
View File

@ -0,0 +1,33 @@
Torch 2 support is extremely pre-alpha. It is not recommended for use.
There is a bunch of triton error spam in logs, I've tried to suppress the errors but it's not working yet.
Clone a fresh copy of the repository; the existing venv is not compatible, and git does not manage the virtual environment for you.
`git clone https://github.com/victorchall/EveryDream2trainer ed2torch2`
`cd ed2torch2`
`git checkout torch2`
run normal install (`windows_setup.cmd` or build the docker container)
Try running training.
If your hyperparameter/grad scaler goes to a tiny or massive value, it's not working. I suggest using validation to make sure it's actually working.
If you have problems, might try to install the latest xformers wheel from the github actions:
Download the xformers wheel built for Python 3.10 and CUDA 11.8 (cu118) for your operating system
https://github.com/facebookresearch/xformers/actions/runs/4501451442
linux: ubuntu-22.04-py3.10-torch2.0.0+cu118
https://github.com/facebookresearch/xformers/suites/11760994158/artifacts/613483176
win: windows-2019-py3.10-torch2.0.0+cu118
https://github.com/facebookresearch/xformers/suites/11760994158/artifacts/613483194
Save the .whl file to your everydream2trainer folder
activate your venv and pip install the wheel file
`pip install xformers-0.0.17.dev484-cp310-cp310-win_amd64.whl`

View File

@ -122,7 +122,23 @@ def setup_local_logger(args):
datefmt="%m/%d/%Y %I:%M:%S %p",
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" in msg.getMessage())
#A matching Triton is not available
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
console_handler.addFilter(lambda msg: "No module named 'triton'" not in msg.getMessage())
console_handler.addFilter(lambda msg: "A matching Triton is not available" not in msg.getMessage())
console_handler.addFilter(lambda msg: "None of the inputs have requires_grad=True" not in msg.getMessage())
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("pytorch").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("torch").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger().addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("PIL.Image").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
logging.getLogger("PIL").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
logging.getLogger("pil").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
logging.getLogger("pillow").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
#logging.getLogger().addFilter(lambda msg: "No module named 'triton'" not in msg.getMessage())
#logging.getLogger().addFilter(lambda msg: "A matching Triton is not available" not in msg.getMessage())
logging.getLogger().addFilter(lambda msg: "None of the inputs have requires_grad=True" not in msg.getMessage())
logging.getLogger().addHandler(console_handler)
import warnings
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
@ -476,12 +492,16 @@ def main(args):
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, using attention slicing instead")
logging.warning("failed to load xformers, falling back to attention slicing instead")
unet.set_attention_slice("auto")
pass
else:
logging.info("xformers cannot be enabled, using attention slicing instead")
unet.set_attention_slice("auto")
else:
logging.info("xformers disabled, using attention slicing instead")
unet.set_attention_slice("auto")
#exit()
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
@ -794,7 +814,7 @@ def main(args):
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
#torch.cuda.empty_cache()
# Pre-train validation to establish a starting point on the loss graph
if validator:
@ -876,7 +896,7 @@ def main(args):
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
#torch.cuda.empty_cache()
if (global_step + 1) % sample_generator.sample_steps == 0:
generate_samples(global_step=global_step, batch=batch)

View File

@ -3,24 +3,30 @@ call "venv\Scripts\activate.bat"
echo should be in venv here
cd .
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
pip install transformers==4.27.1
pip install diffusers[torch]==0.13.0
pip install pynvml==11.4.1
::torch-2.0.0+cu118-cp310-cp310-linux_x86_64.whl
::torch-2.0.0+cu118-cp310-cp310-win_amd64.whl
::torchvision-0.15.1+cu118-cp310-cp310-linux_x86_64.whl
::torchvision-0.15.1+cu118-cp310-cp310-win_amd64.whl
pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
pip install transformers==4.27.3
pip install diffusers[torch]==0.14.0
pip install pynvml==11.5.0
pip install bitsandbytes==0.35.0
git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache
pip install ftfy==6.1.1
pip install aiohttp==3.8.3
pip install tensorboard>=2.11.0
pip install aiohttp==3.8.4
pip install tensorboard>=2.12.0
pip install protobuf==3.20.1
pip install wandb==0.14.0
pip install pyre-extensions==0.0.23
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
::pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
pip install --pre xformers
pip install pytorch-lightning==1.6.5
pip install OmegaConf==2.2.3
pip install numpy==1.23.5
pip install keyboard
pip install lion-pytorch
python utils/patch_bnb.py
python utils/get_yamls.py