torch2 stuff
This commit is contained in:
parent
56256ab9ef
commit
036f7a5818
|
@ -0,0 +1,33 @@
|
|||
Torch 2 support is extremely pre-alpha. It is not recommended for use.
|
||||
|
||||
There is a bunch of triton error spam in logs, I've tried to suppress the errors but it's not working yet.
|
||||
|
||||
Clone a new copy; the existing venv is not compatible, and git will not manage the environment
|
||||
|
||||
`git clone https://github.com/victorchall/EveryDream2trainer ed2torch2`
|
||||
`cd ed2torch2`
|
||||
`git checkout torch2`
|
||||
|
||||
run normal install (`windows_setup.cmd` or build the docker container)
|
||||
|
||||
Try running training.
|
||||
|
||||
If your hyperparameter/grad scaler goes to a tiny or massive value, it's not working. I suggest using validation to make sure it's actually working.
|
||||
|
||||
If you have problems, might try to install the latest xformers wheel from the github actions:
|
||||
|
||||
Download the xformers Python 3.10 cu118 wheel for your system
|
||||
|
||||
https://github.com/facebookresearch/xformers/actions/runs/4501451442
|
||||
|
||||
linux: ubuntu-22.04-py3.10-torch2.0.0+cu118
|
||||
https://github.com/facebookresearch/xformers/suites/11760994158/artifacts/613483176
|
||||
|
||||
win: windows-2019-py3.10-torch2.0.0+cu118
|
||||
https://github.com/facebookresearch/xformers/suites/11760994158/artifacts/613483194
|
||||
|
||||
Save the .whl file to your everydream2trainer folder
|
||||
|
||||
activate your venv and pip install the wheel file
|
||||
|
||||
`pip install xformers-0.0.17.dev484-cp310-cp310-win_amd64.whl`
|
28
train.py
28
train.py
|
@ -122,7 +122,23 @@ def setup_local_logger(args):
|
|||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" in msg.getMessage())
|
||||
#A matching Triton is not available
|
||||
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
console_handler.addFilter(lambda msg: "No module named 'triton'" not in msg.getMessage())
|
||||
console_handler.addFilter(lambda msg: "A matching Triton is not available" not in msg.getMessage())
|
||||
console_handler.addFilter(lambda msg: "None of the inputs have requires_grad=True" not in msg.getMessage())
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
logging.getLogger("pytorch").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
logging.getLogger("torch").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
logging.getLogger().addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
logging.getLogger("PIL.Image").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
logging.getLogger("PIL").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
logging.getLogger("pil").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
logging.getLogger("pillow").addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
#logging.getLogger().addFilter(lambda msg: "No module named 'triton'" not in msg.getMessage())
|
||||
#logging.getLogger().addFilter(lambda msg: "A matching Triton is not available" not in msg.getMessage())
|
||||
logging.getLogger().addFilter(lambda msg: "None of the inputs have requires_grad=True" not in msg.getMessage())
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
|
||||
|
@ -476,12 +492,16 @@ def main(args):
|
|||
unet.enable_xformers_memory_efficient_attention()
|
||||
logging.info("Enabled xformers")
|
||||
except Exception as ex:
|
||||
logging.warning("failed to load xformers, using attention slicing instead")
|
||||
logging.warning("failed to load xformers, falling back to attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
pass
|
||||
else:
|
||||
logging.info("xformers cannot be enabled, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
else:
|
||||
logging.info("xformers disabled, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
#exit()
|
||||
|
||||
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
|
@ -794,7 +814,7 @@ def main(args):
|
|||
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
|
@ -876,7 +896,7 @@ def main(args):
|
|||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
|
||||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
if (global_step + 1) % sample_generator.sample_steps == 0:
|
||||
generate_samples(global_step=global_step, batch=batch)
|
||||
|
|
|
@ -3,24 +3,30 @@ call "venv\Scripts\activate.bat"
|
|||
echo should be in venv here
|
||||
cd .
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
|
||||
pip install transformers==4.27.1
|
||||
pip install diffusers[torch]==0.13.0
|
||||
pip install pynvml==11.4.1
|
||||
|
||||
::torch-2.0.0+cu118-cp310-cp310-linux_x86_64.whl
|
||||
::torch-2.0.0+cu118-cp310-cp310-win_amd64.whl
|
||||
::torchvision-0.15.1+cu118-cp310-cp310-linux_x86_64.whl
|
||||
::torchvision-0.15.1+cu118-cp310-cp310-win_amd64.whl
|
||||
|
||||
pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
|
||||
pip install transformers==4.27.3
|
||||
pip install diffusers[torch]==0.14.0
|
||||
pip install pynvml==11.5.0
|
||||
pip install bitsandbytes==0.35.0
|
||||
git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache
|
||||
pip install ftfy==6.1.1
|
||||
pip install aiohttp==3.8.3
|
||||
pip install tensorboard>=2.11.0
|
||||
pip install aiohttp==3.8.4
|
||||
pip install tensorboard>=2.12.0
|
||||
pip install protobuf==3.20.1
|
||||
pip install wandb==0.14.0
|
||||
pip install pyre-extensions==0.0.23
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
::pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
|
||||
pip install --pre xformers
|
||||
pip install pytorch-lightning==1.6.5
|
||||
pip install OmegaConf==2.2.3
|
||||
pip install numpy==1.23.5
|
||||
pip install keyboard
|
||||
pip install lion-pytorch
|
||||
python utils/patch_bnb.py
|
||||
python utils/get_yamls.py
|
||||
|
|
Loading…
Reference in New Issue