update to torch2, xformers 20, bnb 381
This commit is contained in:
parent
803cadfd53
commit
5c98cdee70
|
@ -30,14 +30,14 @@ class EveryDreamBatch(Dataset):
|
|||
data_loader: `DataLoaderMultiAspect` object
|
||||
debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
|
||||
conditional_dropout: probability of dropping the caption for a given image
|
||||
crop_jitter: number of pixels to jitter the crop by, only for non-square images
|
||||
crop_jitter: percent of maximum cropping for crop jitter (ex 0.02 is two percent)
|
||||
seed: random seed
|
||||
"""
|
||||
def __init__(self,
|
||||
data_loader: DataLoaderMultiAspect,
|
||||
debug_level=0,
|
||||
conditional_dropout=0.02,
|
||||
crop_jitter=20,
|
||||
crop_jitter=0.02,
|
||||
seed=555,
|
||||
tokenizer=None,
|
||||
retain_contrast=False,
|
||||
|
@ -129,7 +129,7 @@ class EveryDreamBatch(Dataset):
|
|||
example = {}
|
||||
save = debug_level > 2
|
||||
|
||||
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
|
||||
image_train_tmp = image_train_item.hydrate(save=save, crop_jitter=self.crop_jitter)
|
||||
|
||||
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
|
||||
image_train_tmp.image = None # hack for now to avoid memory leak
|
||||
|
|
|
@ -27,9 +27,6 @@ import PIL.ImageOps as ImageOps
|
|||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
_RANDOM_TRIM = 0.04
|
||||
|
||||
|
||||
OptionalImageCaption = typing.Optional['ImageCaption']
|
||||
|
||||
class ImageCaption:
|
||||
|
@ -143,7 +140,7 @@ class ImageTrainItem:
|
|||
else:
|
||||
self.image = image
|
||||
self.image_size = image.size
|
||||
self.target_size = None
|
||||
#self.target_size = None
|
||||
|
||||
self.is_undersized = False
|
||||
self.error = None
|
||||
|
@ -165,80 +162,108 @@ class ImageTrainItem:
|
|||
pass
|
||||
return image
|
||||
|
||||
def hydrate(self, crop=False, save=False, crop_jitter=20):
|
||||
def _percent_random_crop(self, image, crop_jitter=0.02):
|
||||
"""
|
||||
randomly crops the image by a percentage of the image size on each of the four sides
|
||||
"""
|
||||
width, height = image.size
|
||||
max_crop_pixels = min(width, height) * crop_jitter
|
||||
|
||||
left_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
right_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
top_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
bottom_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
|
||||
# Calculate the cropping coordinates
|
||||
left = left_crop_pixels
|
||||
right = width - right_crop_pixels
|
||||
top = top_crop_pixels
|
||||
bottom = height - bottom_crop_pixels
|
||||
#print(f"\n *** jitter l: {left}, t: {top}, r: {right}, b: {bottom}, orig w: {width}, h: {height}, max_crop_pixels: {max_crop_pixels}")
|
||||
|
||||
# Crop the image
|
||||
cropped_image = image.crop((left, top, right, bottom))
|
||||
|
||||
cropped_width = width - int(left_crop_pixels + right_crop_pixels)
|
||||
cropped_height = height - int(top_crop_pixels + bottom_crop_pixels)
|
||||
|
||||
cropped_aspect_ratio = cropped_width / cropped_height
|
||||
|
||||
# Resize the cropped image to maintain square pixels
|
||||
if cropped_aspect_ratio > 1:
|
||||
new_width = cropped_width
|
||||
new_height = int(cropped_width / cropped_aspect_ratio)
|
||||
else:
|
||||
new_width = int(cropped_height * cropped_aspect_ratio)
|
||||
new_height = cropped_height
|
||||
|
||||
#print(f" *** postsquarefix new w: {new_width}, h: {new_height}")
|
||||
cropped_image = cropped_image.resize((new_width, new_height))
|
||||
|
||||
return cropped_image
|
||||
|
||||
def _debug_save_image(self, image, folder=""):
|
||||
base_name = os.path.basename(self.pathname)
|
||||
target_dir = os.path.join('test/output', folder)
|
||||
target_file = os.path.join(target_dir, base_name)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
os.makedirs(target_dir)
|
||||
|
||||
try:
|
||||
#print(f"saving to test/output: {os.path.join('test/output', folder, base_name)}")
|
||||
image.save(target_file)
|
||||
except Exception as e:
|
||||
print(f"error for debug saving image of {self.pathname}: {e}")
|
||||
pass
|
||||
|
||||
def _trim_to_aspect(self, image, target_wh):
|
||||
width, height = image.size
|
||||
target_aspect = target_wh[0] / target_wh[1] # 0.60
|
||||
image_aspect = width / height # 0.5865
|
||||
#self._debug_save_image(image, "precrop")
|
||||
if image_aspect > target_aspect:
|
||||
target_width = int(height * target_aspect)
|
||||
overwidth = width - target_width
|
||||
l = random.normalvariate(overwidth/2, overwidth/2)
|
||||
l = max(0, l)
|
||||
l = min(l, overwidth)
|
||||
r = width - int(overwidth) - l
|
||||
image = image.crop((l, 0, r, height))
|
||||
elif target_aspect > image_aspect:
|
||||
target_height = int(width / target_aspect)
|
||||
overheight = height - target_height
|
||||
image = image.crop((0, int(overheight/2), width, height-int(overheight/2)))
|
||||
|
||||
def hydrate(self, save=False, crop_jitter=0.02):
|
||||
"""
|
||||
crop: hard center crop to 512x512
|
||||
save: save the cropped image to disk, for manual inspection of resize/crop
|
||||
crop_jitter: randomly shift cropp by N pixels when using multiple aspect ratios to improve training quality
|
||||
"""
|
||||
# print(self.pathname, self.image)
|
||||
try:
|
||||
# try:
|
||||
# if not hasattr(self, 'image'):
|
||||
self.image = self.load_image()
|
||||
image = self.load_image()
|
||||
|
||||
width, height = self.image.size
|
||||
if crop:
|
||||
cropped_img = self.__autocrop(self.image)
|
||||
self.image = cropped_img.resize((512, 512), resample=PIL.Image.BICUBIC)
|
||||
else:
|
||||
width, height = self.image.size
|
||||
jitter_amount = random.randint(0, crop_jitter)
|
||||
#print(f"** jittering: {self.pathname}")
|
||||
|
||||
if self.target_wh[0] == self.target_wh[1]:
|
||||
if width > height:
|
||||
left = random.randint(0, width - height)
|
||||
self.image = self.image.crop((left, 0, height + left, height))
|
||||
width = height
|
||||
elif height > width:
|
||||
top = random.randint(0, height - width)
|
||||
self.image = self.image.crop((0, top, width, width + top))
|
||||
height = width
|
||||
elif width > self.target_wh[0]:
|
||||
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
|
||||
slicew_ratio = random.random()
|
||||
left = int(slice * slicew_ratio)
|
||||
right = width - int(slice * (1 - slicew_ratio))
|
||||
sliceh_ratio = random.random()
|
||||
top = int(slice * sliceh_ratio)
|
||||
bottom = height - int(slice * (1 - sliceh_ratio))
|
||||
width, height = image.size
|
||||
|
||||
self.image = self.image.crop((left, top, right, bottom))
|
||||
else:
|
||||
image_aspect = width / height
|
||||
target_aspect = self.target_wh[0] / self.target_wh[1]
|
||||
if image_aspect > target_aspect:
|
||||
new_width = int(height * target_aspect)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
|
||||
left = jitter_amount
|
||||
right = left + new_width
|
||||
self.image = self.image.crop((left, 0, right, height))
|
||||
else:
|
||||
new_height = int(width / target_aspect)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
|
||||
top = jitter_amount
|
||||
bottom = top + new_height
|
||||
self.image = self.image.crop((0, top, width, bottom))
|
||||
self.image = self.image.resize(self.target_wh, resample=PIL.Image.BICUBIC)
|
||||
img_jitter = min((width-self.target_wh[0])/self.target_wh[0], (height-self.target_wh[1])/self.target_wh[1])
|
||||
img_jitter = min(img_jitter, crop_jitter)
|
||||
img_jitter = max(img_jitter, 0.0)
|
||||
|
||||
if img_jitter > 0.0:
|
||||
image = self._percent_random_crop(image, img_jitter)
|
||||
|
||||
self.image = self.flip(self.image)
|
||||
except Exception as e:
|
||||
logging.error(f"Fatal Error loading image: {self.pathname}:")
|
||||
logging.error(e)
|
||||
exit()
|
||||
self._trim_to_aspect(image, self.target_wh)
|
||||
|
||||
if type(self.image) is not np.ndarray:
|
||||
if save:
|
||||
base_name = os.path.basename(self.pathname)
|
||||
if not os.path.exists("test/output"):
|
||||
os.makedirs("test/output")
|
||||
self.image.save(f"test/output/{base_name}")
|
||||
self.image = image.resize(self.target_wh)
|
||||
|
||||
self.image = np.array(self.image).astype(np.uint8)
|
||||
|
||||
# self.image = (self.image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
# print(self.image.shape)
|
||||
self.image = self.flip(self.image)
|
||||
# self._debug_save_image(self.image, "final")
|
||||
|
||||
self.image = np.array(self.image).astype(np.uint8)
|
||||
|
||||
return self
|
||||
|
||||
def __compute_target_width_height(self):
|
||||
|
@ -250,7 +275,7 @@ class ImageTrainItem:
|
|||
image_aspect = width / height
|
||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
||||
|
||||
self.is_undersized = (width * height) < (target_wh[0] * target_wh[1])
|
||||
self.is_undersized = (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
|
||||
self.target_wh = target_wh
|
||||
except Exception as e:
|
||||
self.error = e
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
transformers==4.29.2
|
||||
diffusers[torch]==0.14.0
|
||||
pynvml==11.4.1
|
||||
bitsandbytes==0.38.1
|
||||
ftfy==6.1.1
|
||||
aiohttp==3.8.4
|
||||
tensorboard>=2.11.0
|
||||
protobuf==3.20.1
|
||||
pyre-extensions==0.0.29
|
||||
xformers==0.0.20
|
||||
pytorch-lightning==1.6.5
|
||||
OmegaConf==2.2.3
|
||||
numpy==1.23.5
|
||||
lion-pytorch
|
||||
compel~=1.1.3
|
||||
OmegaConf==2.2.3
|
||||
numpy==1.23.5
|
||||
wandb
|
4
train.py
4
train.py
|
@ -357,6 +357,10 @@ def main(args):
|
|||
"""
|
||||
Main entry point
|
||||
"""
|
||||
if os.name == 'nt':
|
||||
print(" * Windows detected, disabling Triton")
|
||||
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = "1"
|
||||
|
||||
log_time = setup_local_logger(args)
|
||||
args = setup_args(args)
|
||||
print(f" Args:")
|
||||
|
|
|
@ -3,27 +3,23 @@ call "venv\Scripts\activate.bat"
|
|||
echo should be in venv here
|
||||
cd .
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
|
||||
pip install -U transformers==4.27.1
|
||||
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
|
||||
pip install -U transformers==4.29.2
|
||||
pip install -U diffusers[torch]==0.14.0
|
||||
pip install pynvml==11.4.1
|
||||
pip install -U https://github.com/victorchall/everydream-whls/raw/main/bitsandbytes-0.38.1-py2.py3-none-any.whl
|
||||
git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache
|
||||
pip install ftfy==6.1.1
|
||||
pip install aiohttp==3.8.3
|
||||
pip install aiohttp==3.8.4
|
||||
pip install tensorboard>=2.11.0
|
||||
pip install protobuf==3.20.1
|
||||
pip install wandb==0.14.0
|
||||
pip install pyre-extensions==0.0.23
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
|
||||
pip install wandb==0.15.3
|
||||
pip install pyre-extensions==0.0.29
|
||||
pip install -U xformers==0.0.20
|
||||
pip install pytorch-lightning==1.6.5
|
||||
pip install OmegaConf==2.2.3
|
||||
pip install numpy==1.23.5
|
||||
pip install keyboard
|
||||
pip install lion-pytorch
|
||||
pip install compel~=1.1.3
|
||||
python utils/patch_bnb.py
|
||||
python utils/get_yamls.py
|
||||
GOTO :eof
|
||||
|
||||
|
|
Loading…
Reference in New Issue