Misc minor fixes

This commit is contained in:
reijerh 2023-11-08 23:49:38 +01:00
parent 6ea721887c
commit 7de666ec2d
8 changed files with 59 additions and 44 deletions

View File

@ -125,7 +125,8 @@
" 'colorama',\n",
" 'keyboard',\n",
" 'lion-pytorch',\n",
" 'safetensors'\n",
" 'safetensors',\n",
" 'torchsde'\n",
"]\n",
"\n",
"print(colored(0, 255, 0, 'Installing packages...'))\n",

View File

@ -279,9 +279,6 @@ class ImageTrainItem:
return image
def hydrate(self, save=False, crop_jitter=0.02):
"""
save: save the cropped image to disk, for manual inspection of resize/crop
"""
image = self.load_image()
width, height = image.size
@ -298,8 +295,9 @@ class ImageTrainItem:
self.image = image.resize(self.target_wh)
self.image = self.flip(self.image)
# Remove comment here to view image cropping outputs
#self._debug_save_image(self.image, "final")
if save:
self._debug_save_image(self.image, "final")
self.image = np.array(self.image).astype(np.uint8)
@ -313,6 +311,7 @@ class ImageTrainItem:
height, width = image.size
else:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))

View File

@ -17,3 +17,4 @@ tensorboard==2.12.0
wandb
safetensors
prodigyopt
torchsde

View File

@ -223,23 +223,23 @@ class EveryDreamOptimizer():
if args.lr_warmup_steps is None:
# set warmup to 2% of decay, if decay was autoset to 150% of max epochs then warmup will end up about 3% of max epochs
args.lr_warmup_steps = int(args.lr_decay_steps / 50)
if args.lr is not None:
# override for legacy support reasons
base_config["lr"] = args.lr
base_config["optimizer"] = base_config.get("optimizer", None) or "adamw8bit"
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", args.lr_warmup_steps)
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
te_config["lr"] = te_config.get("lr", None) or base_config["lr"]
te_config["optimizer"] = te_config.get("optimizer", None) or base_config["optimizer"]
te_config["lr_scheduler"] = te_config.get("lr_scheduler", None) or base_config["lr_scheduler"]
te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", None) or base_config["lr_warmup_steps"]
te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", base_config["lr_warmup_steps"])
te_config["lr_decay_steps"] = te_config.get("lr_decay_steps", None) or base_config["lr_decay_steps"]
te_config["weight_decay"] = te_config.get("weight_decay", None) or base_config["weight_decay"]
te_config["betas"] = te_config.get("betas", None) or base_config["betas"]

View File

@ -22,3 +22,4 @@ wandb
colorama
safetensors
open-flamingo==2.0.0
torchsde

View File

@ -14,36 +14,36 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import os
import random
import argparse
from datetime import datetime
from torch import no_grad
from torch.cuda.amp import autocast
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
#from diffusers.models import AttentionBlock
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DPMSolverSDEScheduler
from torch.cuda.amp import autocast
from transformers import CLIPTextModel, CLIPTokenizer
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen, steps: int = 30):
# from diffusers.models import AttentionBlock
def __generate_sample(pipe: StableDiffusionPipeline, prompt: str, cfg: float, height: int, width: int, gen,
steps: int = 30, batch_size: int = 1):
"""
generates a single sample at a given cfg scale and saves it to disk
"""
"""
with autocast():
image = pipe(prompt,
num_inference_steps=steps,
num_images_per_prompt=1,
guidance_scale=cfg,
generator=gen,
height=resolution,
width=resolution,
).images[0]
return image
images = pipe(prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=cfg,
generator=gen,
height=height,
width=width,
).images
return images
def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
"""
@ -55,14 +55,18 @@ def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None,  # must be None if no safety checker
)
return pipe
def main(args):
# Create output folder if it doesn't exist
os.makedirs('output', exist_ok=True)
text_encoder = CLIPTextModel.from_pretrained(args.diffusers_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.diffusers_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.diffusers_path, subfolder="unet")
@ -79,23 +83,32 @@ def main(args):
pipe = __create_inference_pipe(unet, text_encoder, tokenizer, sample_scheduler, vae)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device="cuda").manual_seed(seed)
for _ in range(args.batch_count):
seed = args.seed if args.seed != -1 else random.randint(0, 2 ** 30)
gen = torch.Generator(device="cuda").manual_seed(seed)
img = __generate_sample(pipe, args.prompt, args.cfg_scale, args.resolution, gen=gen, steps=args.steps)
images = __generate_sample(pipe, args.prompt, args.cfg_scale, args.height, args.width, gen=gen,
steps=args.steps,
batch_size=args.batch_size)
img.save(f"img_{args.prompt[0:100].replace(' ', '_')}_cfg_{args.cfg_scale}.png")
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(images):
img.save(
f"output/img_{args.prompt[0:210].replace(' ', '_')}_cfg_{args.cfg_scale}_{i}_{seed}_{timestamp}.png")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--diffusers_path', type=str, default=None, required=True, help='path to diffusers model (from logs)')
parser.add_argument('--diffusers_path', type=str, default=None, required=True,
help='path to diffusers model (from logs)')
parser.add_argument('--prompt', type=str, required=True, help='prompt to use')
parser.add_argument('--resolution', type=int, default=512, help='resolution (def: 512)')
parser.add_argument('--height', type=int, default=512, help='height (def: 512)')
parser.add_argument('--width', type=int, default=512, help='width (def: 512)')
parser.add_argument('--seed', type=int, default=-1, help='seed, use -1 for random (def: -1)')
parser.add_argument('--steps', type=int, default=50, help='inference, denoising steps (def: 50)')
parser.add_argument('--cfg_scale', type=int, default=7.5, help='unconditional guidance scale (def: 7.5)')
parser.add_argument('--cfg_scale', type=float, default=7.5, help='unconditional guidance scale (def: 7.5)')
parser.add_argument('--batch_size', type=int, default=1, help='batch size (def: 1)')
parser.add_argument('--batch_count', type=int, default=1, help='batch count (def: 1)')
args = parser.parse_args()
main(args)
main(args)

View File

@ -390,7 +390,6 @@ def setup_args(args):
if args.grad_accum > 1:
logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)

View File

@ -24,6 +24,7 @@ pip install dadaptation
pip install safetensors
pip install open-flamingo==2.0.0
pip install prodigyopt
pip install torchsde
python utils/get_yamls.py
GOTO :eof