update to torch2, xformers 20, bnb 381

This commit is contained in:
Victor Hall 2023-05-30 22:15:02 -04:00
parent 803cadfd53
commit 5c98cdee70
5 changed files with 125 additions and 80 deletions

View File

@ -30,14 +30,14 @@ class EveryDreamBatch(Dataset):
data_loader: `DataLoaderMultiAspect` object
debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
conditional_dropout: probability of dropping the caption for a given image
crop_jitter: number of pixels to jitter the crop by, only for non-square images
crop_jitter: percent of maximum cropping for crop jitter (ex 0.02 is two percent)
seed: random seed
"""
def __init__(self,
data_loader: DataLoaderMultiAspect,
debug_level=0,
conditional_dropout=0.02,
crop_jitter=20,
crop_jitter=0.02,
seed=555,
tokenizer=None,
retain_contrast=False,
@ -129,7 +129,7 @@ class EveryDreamBatch(Dataset):
example = {}
save = debug_level > 2
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
image_train_tmp = image_train_item.hydrate(save=save, crop_jitter=self.crop_jitter)
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
image_train_tmp.image = None # hack for now to avoid memory leak

View File

@ -27,9 +27,6 @@ import PIL.ImageOps as ImageOps
import numpy as np
from torchvision import transforms
_RANDOM_TRIM = 0.04
OptionalImageCaption = typing.Optional['ImageCaption']
class ImageCaption:
@ -143,7 +140,7 @@ class ImageTrainItem:
else:
self.image = image
self.image_size = image.size
self.target_size = None
#self.target_size = None
self.is_undersized = False
self.error = None
@ -165,80 +162,108 @@ class ImageTrainItem:
pass
return image
def hydrate(self, crop=False, save=False, crop_jitter=20):
def _percent_random_crop(self, image, crop_jitter=0.02):
    """
    randomly crops the image by a percentage of the image size on each of the four sides

    image: PIL-style image exposing `.size` as (width, height) and `.crop(box)`
    crop_jitter: maximum amount trimmed from each side, as a fraction of the
                 shorter image side (ex 0.02 is two percent)

    returns the cropped image; the crop box uses whole pixels, so the result is
    exactly the cropped region and no resampling is applied
    """
    width, height = image.size
    shortest_side = min(width, height)

    # Cap the per-side trim so opposite sides can never meet: even with an
    # oversized crop_jitter at least one pixel survives on each axis.
    max_crop_pixels = min(shortest_side * max(crop_jitter, 0.0), (shortest_side - 1) / 2)
    max_crop_pixels = max(max_crop_pixels, 0.0)

    # One independent draw per side, truncated to whole pixels so the crop box
    # and the resulting image size always agree.
    left_crop_pixels = int(random.uniform(0, max_crop_pixels))
    right_crop_pixels = int(random.uniform(0, max_crop_pixels))
    top_crop_pixels = int(random.uniform(0, max_crop_pixels))
    bottom_crop_pixels = int(random.uniform(0, max_crop_pixels))

    crop_box = (
        left_crop_pixels,
        top_crop_pixels,
        width - right_crop_pixels,
        height - bottom_crop_pixels,
    )
    return image.crop(crop_box)
def _debug_save_image(self, image, folder=""):
    """
    best-effort save of an image under test/output/<folder>/ for manual inspection

    image: object exposing `.save(path)`
    folder: optional sub-folder of test/output to save into

    the file keeps the base name of self.pathname; any failure (creating the
    folder or writing the file) is printed and swallowed, nothing is raised
    """
    base_name = os.path.basename(self.pathname)
    target_dir = os.path.join('test/output', folder)
    target_file = os.path.join(target_dir, base_name)
    try:
        # exist_ok avoids the check-then-create race when several callers
        # try to create the same folder at once
        os.makedirs(target_dir, exist_ok=True)
        #print(f"saving to test/output: {os.path.join('test/output', folder, base_name)}")
        image.save(target_file)
    except Exception as e:
        print(f"error for debug saving image of {self.pathname}: {e}")
def _trim_to_aspect(self, image, target_wh):
    """
    crops the image (no resizing) so its aspect ratio matches target_wh

    image: PIL-style image exposing `.size` as (width, height) and `.crop(box)`
    target_wh: (width, height) pair whose ratio the result should have

    too wide: keeps the full height and a target-width window at a random
              horizontal offset (normally distributed around the centre,
              clamped to the image)
    too tall: keeps the full width and a vertically centred window

    returns the cropped image, or the input itself if the aspect already matches
    """
    width, height = image.size
    target_aspect = target_wh[0] / target_wh[1]
    image_aspect = width / height
    #self._debug_save_image(image, "precrop")

    if image_aspect > target_aspect:
        target_width = int(height * target_aspect)
        overwidth = width - target_width
        # Favour the centre but allow any offset; clamp into [0, overwidth]
        # so the window always lies fully inside the image.
        left = int(random.normalvariate(overwidth / 2, overwidth / 2))
        left = min(max(left, 0), overwidth)
        image = image.crop((left, 0, left + target_width, height))
    elif target_aspect > image_aspect:
        target_height = int(width / target_aspect)
        overheight = height - target_height
        # Anchor the bottom edge to the top one so an odd excess cannot
        # leave the window one pixel too tall.
        top = overheight // 2
        image = image.crop((0, top, width, top + target_height))

    return image
def hydrate(self, save=False, crop_jitter=0.02):
    """
    loads the image and prepares it for training: optional random jitter crop,
    trim to the target aspect ratio, resize to self.target_wh, flip, and
    conversion to a uint8 numpy array stored in self.image

    save: save the final image to disk (test/output/final), for manual inspection of resize/crop
    crop_jitter: maximum fraction of the shorter image side randomly trimmed from
                 each edge before the aspect trim (ex 0.02 is two percent)

    returns self
    """
    image = self.load_image()
    width, height = image.size
    target_w, target_h = self.target_wh

    # Jitter only as far as the image exceeds the target on its tightest axis;
    # an image at or below the target size on either axis gets no jitter crop.
    headroom = min((width - target_w) / target_w, (height - target_h) / target_h)
    img_jitter = max(min(headroom, crop_jitter), 0.0)

    if img_jitter > 0.0:
        image = self._percent_random_crop(image, img_jitter)

    # Use the trimmed image when the helper hands one back; keep the
    # untrimmed image if it returns None.
    trimmed = self._trim_to_aspect(image, self.target_wh)
    if trimmed is not None:
        image = trimmed

    self.image = image.resize(self.target_wh)
    self.image = self.flip(self.image)

    if save:
        self._debug_save_image(self.image, "final")

    self.image = np.array(self.image).astype(np.uint8)

    return self
def __compute_target_width_height(self):
@ -250,7 +275,7 @@ class ImageTrainItem:
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
self.is_undersized = (width * height) < (target_wh[0] * target_wh[1])
self.is_undersized = (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
self.target_wh = target_wh
except Exception as e:
self.error = e

20
requirements.txt Normal file
View File

@ -0,0 +1,20 @@
torch==2.0.1
torchvision==0.15.2
transformers==4.29.2
diffusers[torch]==0.14.0
pynvml==11.4.1
bitsandbytes==0.38.1
ftfy==6.1.1
aiohttp==3.8.4
tensorboard>=2.11.0
protobuf==3.20.1
pyre-extensions==0.0.29
xformers==0.0.20
pytorch-lightning==1.6.5
OmegaConf==2.2.3
numpy==1.23.5
lion-pytorch
compel~=1.1.3
OmegaConf==2.2.3
numpy==1.23.5
wandb

View File

@ -357,6 +357,10 @@ def main(args):
"""
Main entry point
"""
if os.name == 'nt':
print(" * Windows detected, disabling Triton")
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = "1"
log_time = setup_local_logger(args)
args = setup_args(args)
print(f" Args:")

View File

@ -3,27 +3,23 @@ call "venv\Scripts\activate.bat"
echo should be in venv here
cd .
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
pip install -U transformers==4.27.1
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
pip install -U transformers==4.29.2
pip install -U diffusers[torch]==0.14.0
pip install pynvml==11.4.1
pip install -U https://github.com/victorchall/everydream-whls/raw/main/bitsandbytes-0.38.1-py2.py3-none-any.whl
git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache
pip install ftfy==6.1.1
pip install aiohttp==3.8.3
pip install aiohttp==3.8.4
:: Quoted so cmd.exe does not treat ">" as output redirection, which would drop the version constraint
pip install "tensorboard>=2.11.0"
pip install protobuf==3.20.1
pip install wandb==0.14.0
pip install pyre-extensions==0.0.23
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
pip install wandb==0.15.3
pip install pyre-extensions==0.0.29
pip install -U xformers==0.0.20
pip install pytorch-lightning==1.6.5
pip install OmegaConf==2.2.3
pip install numpy==1.23.5
pip install keyboard
pip install lion-pytorch
pip install compel~=1.1.3
python utils/patch_bnb.py
python utils/get_yamls.py
GOTO :eof