diff --git a/models/vision/ddim/example.py b/models/vision/ddim/example.py index 52f75b62..cd9abfb7 100755 --- a/models/vision/ddim/example.py +++ b/models/vision/ddim/example.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 import os import pathlib -from modeling_ddim import DDIM -import PIL.Image + import numpy as np +import PIL.Image +from modeling_ddim import DDIM + + model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"] for model_id in model_ids: diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py index 1e1ffea0..2ff8dacc 100644 --- a/models/vision/ddim/modeling_ddim.py +++ b/models/vision/ddim/modeling_ddim.py @@ -14,13 +14,13 @@ # limitations under the License. -from diffusers import DiffusionPipeline -import tqdm import torch +import tqdm +from diffusers import DiffusionPipeline + class DDIM(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) @@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline): inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) self.unet.to(torch_device) - image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) + image = self.noise_scheduler.sample_noise( + (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), + device=torch_device, + generator=generator, + ) for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): # get actual t and t-1 train_step = inference_step_times[t] - prev_train_step = inference_step_times[t - 1] if t > 0 else - 1 + prev_train_step = inference_step_times[t - 1] if t > 0 else -1 # compute alphas alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) @@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline): beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt() # compute relevant coefficients - coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta - coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt() + coeff_1 = ( + (alpha_prod_t_prev - alpha_prod_t).sqrt() + * alpha_prod_t_prev_rsqrt + * beta_prod_t_prev_sqrt + / beta_prod_t_sqrt + * eta + ) + coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt() # model forward with torch.no_grad(): diff --git a/models/vision/ddim/run_inference.py b/models/vision/ddim/run_inference.py index 59ed5865..babe5d96 100755 --- a/models/vision/ddim/run_inference.py +++ b/models/vision/ddim/run_inference.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 # !pip install diffusers -from modeling_ddim import DDIM -import PIL.Image import numpy as np +import PIL.Image +from modeling_ddim import DDIM + + model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-lsun-bedroom" diff --git a/models/vision/ddpm/example.py b/models/vision/ddpm/example.py index c0864074..5eb5a769 100755 --- a/models/vision/ddpm/example.py +++ b/models/vision/ddpm/example.py @@ -1,11 +1,25 @@ #!/usr/bin/env python3 import os import pathlib -from modeling_ddpm import DDPM -import PIL.Image + import numpy as np -model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"] +import PIL.Image +from modeling_ddpm import DDPM + + +model_ids = [ + "ddpm-lsun-cat", + "ddpm-lsun-cat-ema", + "ddpm-lsun-church-ema", + "ddpm-lsun-church", + "ddpm-lsun-bedroom", + "ddpm-lsun-bedroom-ema", + "ddpm-cifar10-ema", + "ddpm-cifar10", + "ddpm-celeba-hq", + "ddpm-celeba-hq-ema", +] for model_id in model_ids: path = os.path.join("/home/patrick/images/hf", model_id) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index f84ab452..584a6145 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -14,13 +14,13 @@ # limitations under the License. -from diffusers import DiffusionPipeline -import tqdm import torch +import tqdm +from diffusers import DiffusionPipeline + class DDPM(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) @@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline): self.unet.to(torch_device) # 1. Sample gaussian noise - image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) + image = self.noise_scheduler.sample_noise( + (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), + device=torch_device, + generator=generator, + ) for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): # i) define coefficients for time step t clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) - image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t)) - clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t)) + image_coeff = ( + (1 - self.noise_scheduler.get_alpha_prod(t - 1)) + * torch.sqrt(self.noise_scheduler.get_alpha(t)) + / (1 - self.noise_scheduler.get_alpha_prod(t)) + ) + clipped_coeff = ( + torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) + * self.noise_scheduler.get_beta(t) + / (1 - self.noise_scheduler.get_alpha_prod(t)) + ) # ii) predict noise residual with torch.no_grad(): @@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline): prev_image = clipped_coeff * pred_mean + image_coeff * image # iv) sample variance - prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) + prev_variance = self.noise_scheduler.sample_variance( + t, prev_image.shape, device=torch_device, generator=generator + ) # v) sample x_{t-1} ~ N(prev_image, prev_variance) sampled_prev_image = prev_image + prev_variance diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index 5bcc68ff..c9ef9dfc 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -1,7 +1,13 @@ import torch from torch import nn -from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from diffusers import ( + ClassifierFreeGuidanceScheduler, + CLIPTextModel, + GlideDDIMScheduler, + GLIDESuperResUNetModel, + GLIDETextToImageUNetModel, +) from modeling_glide import GLIDE from transformers import CLIPTextConfig, GPT2Tokenizer @@ -22,7 +28,9 @@ config = CLIPTextConfig( use_padding_embeddings=True, ) model = CLIPTextModel(config).eval() -tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>") +tokenizer = GPT2Tokenizer( + "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>" +) hf_encoder = model.text_model @@ -97,10 +105,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False) upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear") -glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer, - upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler) +glide = GLIDE( + text_unet=text2im_model, + text_noise_scheduler=text_scheduler, + text_encoder=model, + tokenizer=tokenizer, + upscale_unet=superres_model, + upscale_noise_scheduler=upscale_scheduler, +) glide.save_pretrained("./glide-base") - - - diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index 9ccb6625..d2a15ffd 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -18,7 +18,14 @@ import numpy as np import torch import tqdm -from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from diffusers import ( + ClassifierFreeGuidanceScheduler, + CLIPTextModel, + DiffusionPipeline, + GlideDDIMScheduler, + GLIDESuperResUNetModel, + GLIDETextToImageUNetModel, +) from transformers import GPT2Tokenizer @@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer, upscale_unet: GLIDESuperResUNetModel, - upscale_noise_scheduler: GlideDDIMScheduler + upscale_noise_scheduler: GlideDDIMScheduler, ): super().__init__() self.register_modules( - text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, - upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler + text_unet=text_unet, + text_noise_scheduler=text_noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + upscale_unet=upscale_unet, + upscale_noise_scheduler=upscale_noise_scheduler, ) def q_posterior_mean_variance(self, scheduler, x_start, x_t, t): @@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline): + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor( - scheduler.posterior_log_variance_clipped, t, x_t.shape - ) + posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] @@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline): # A value of 1.0 is sharper, but sometimes results in grainy artifacts. upsample_temp = 0.997 - image = self.upscale_noise_scheduler.sample_noise( - (batch_size, 3, 256, 256), device=torch_device, generator=generator - ) * upsample_temp + image = ( + self.upscale_noise_scheduler.sample_noise( + (batch_size, 3, 256, 256), device=torch_device, generator=generator + ) + * upsample_temp + ) num_timesteps = len(self.upscale_noise_scheduler) - for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)): + for t in tqdm.tqdm( + reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler) + ): # i) define coefficients for time step t clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t)) clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1) - image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt( - self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) - clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta( - t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) + image_coeff = ( + (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) + * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t)) + / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) + ) + clipped_coeff = ( + torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) + * self.upscale_noise_scheduler.get_beta(t) + / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) + ) # ii) predict noise residual time_input = torch.tensor([t] * image.shape[0], device=torch_device) @@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline): prev_image = clipped_coeff * pred_mean + image_coeff * image # iv) sample variance - prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, - generator=generator) + prev_variance = self.upscale_noise_scheduler.sample_variance( + t, prev_image.shape, device=torch_device, generator=generator + ) # v) sample x_{t-1} ~ N(prev_image, prev_variance) sampled_prev_image = prev_image + prev_variance diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index c4351291..6a00b12a 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -1,6 +1,8 @@ import torch -from diffusers import DiffusionPipeline + import PIL.Image +from diffusers import DiffusionPipeline + generator = torch.Generator() generator = generator.manual_seed(0) @@ -14,8 +16,8 @@ pipeline = DiffusionPipeline.from_pretrained(model_id) img = pipeline("a clip art of a hugging face", generator) # process image to PIL -img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() +img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() image_pil = PIL.Image.fromarray(img) # save image -image_pil.save("test.png") \ No newline at end of file +image_pil.save("test.png") diff --git a/setup.py b/setup.py index 96d1f309..da605822 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ _deps = [ "isort>=5.5.4", "numpy", "pytest", + "regex!=2019.12.17", "requests", "torch>=1.4", "torchvision", @@ -168,6 +169,7 @@ install_requires = [ deps["filelock"], deps["huggingface-hub"], deps["numpy"], + deps["regex"], deps["requests"], deps["torch"], deps["torchvision"], diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c9285df3..850f059c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,7 +7,7 @@ __version__ = "0.0.1" from .modeling_utils import ModelMixin from .models.clip_text_transformer import CLIPTextModel from .models.unet import UNetModel -from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_ldm import UNetLDMModel from .models.vqvae import VQModel from .pipeline_utils import DiffusionPipeline diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fc7e04cc..51c27e33 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -23,13 +23,13 @@ import os import re from typing import Any, Dict, Tuple, Union -from requests import HTTPError from huggingface_hub import hf_hub_download +from requests import HTTPError - +from . import __version__ from .utils import ( - HUGGINGFACE_CO_RESOLVE_ENDPOINT, DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, @@ -37,9 +37,6 @@ from .utils import ( ) -from . import __version__ - - logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") @@ -95,9 +92,7 @@ class ConfigMixin: @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): - config_dict = cls.get_config_dict( - pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs - ) + config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) @@ -157,16 +152,16 @@ class ConfigMixin: except RepositoryNotFoundError: raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on " - "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having " - "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass " - "`use_auth_token=True`." + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed" + " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token" + " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and" + " pass `use_auth_token=True`." ) except RevisionNotFoundError: raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " - f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for " - "available revisions." + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) except EntryNotFoundError: raise EnvironmentError( @@ -174,14 +169,16 @@ class ConfigMixin: ) except HTTPError as err: raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" ) except ValueError: raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" - f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" - f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the" - " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( @@ -195,9 +192,7 @@ class ConfigMixin: # Load config dict config_dict = cls._dict_from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError( - f"It looks like the config file at '{config_file}' is not a valid JSON file." - ) + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") return config_dict diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 6b552e0d..b972b9a0 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -3,29 +3,15 @@ # 2. run `make deps_table_update`` deps = { "Pillow": "Pillow", - "accelerate": "accelerate>=0.9.0", "black": "black~=22.0,>=22.3", - "codecarbon": "codecarbon==1.2.0", - "dataclasses": "dataclasses", - "datasets": "datasets", - "GitPython": "GitPython<3.1.19", - "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.1.0,<1.0", - "importlib_metadata": "importlib_metadata", + "filelock": "filelock", + "flake8": "flake8>=3.8.3", + "huggingface-hub": "huggingface-hub", "isort": "isort>=5.5.4", - "numpy": "numpy>=1.17", + "numpy": "numpy", "pytest": "pytest", - "pytest-timeout": "pytest-timeout", - "pytest-xdist": "pytest-xdist", - "python": "python>=3.7.0", "regex": "regex!=2019.12.17", "requests": "requests", - "sagemaker": "sagemaker>=2.31.0", - "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13", "torch": "torch>=1.4", - "torchaudio": "torchaudio", - "tqdm": "tqdm>=4.27", - "unidic": "unidic>=1.0.2", - "unidic_lite": "unidic_lite>=1.0.7", - "uvicorn": "uvicorn", + "torchvision": "torchvision", } diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py index c1ca34ee..0ebf916e 100644 --- a/src/diffusers/dynamic_modules_utils.py +++ b/src/diffusers/dynamic_modules_utils.py @@ -23,7 +23,8 @@ from pathlib import Path from typing import Dict, Optional, Union from huggingface_hub import cached_download -from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging + +from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 20870e34..dd3c6e9e 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union import torch from torch import Tensor, device -from requests import HTTPError from huggingface_hub import hf_hub_download +from requests import HTTPError from .utils import ( CONFIG_NAME, @@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module): f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) except EntryNotFoundError: - raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.") + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}." + ) except HTTPError as err: raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" ) except ValueError: raise EnvironmentError( diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c6405ed8..fb312f06 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,6 +18,6 @@ from .clip_text_transformer import CLIPTextModel from .unet import UNetModel -from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel +from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .unet_ldm import UNetLDMModel -from .vqvae import VQModel \ No newline at end of file +from .vqvae import VQModel diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 470d420c..56211711 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast from torch.optim import Adam from torch.utils import data -from torchvision import transforms, utils from PIL import Image +from torchvision import transforms, utils from tqdm import tqdm from ..configuration_utils import ConfigMixin @@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin): # dataset classes + class Dataset(data.Dataset): - def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): + def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]): super().__init__() self.folder = folder self.image_size = image_size - self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")] - self.transform = transforms.Compose([ - transforms.Resize(image_size), - transforms.RandomHorizontalFlip(), - transforms.CenterCrop(image_size), - transforms.ToTensor() - ]) + self.transform = transforms.Compose( + [ + transforms.Resize(image_size), + transforms.RandomHorizontalFlip(), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + ] + ) def __len__(self): return len(self.paths) @@ -359,7 +362,7 @@ class Dataset(data.Dataset): # trainer class -class EMA(): +class EMA: def __init__(self, beta): super().__init__() self.beta = beta diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 24ef868b..d017b992 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): """ def __init__( - self, - in_channels=3, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - transformer_dim=512 + self, + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + transformer_dim=512, ): super().__init__( in_channels=in_channels, @@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, - transformer_dim=transformer_dim + transformer_dim=transformer_dim, ) self.register( in_channels=in_channels, @@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, - transformer_dim=transformer_dim + transformer_dim=transformer_dim, ) self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) @@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): """ def __init__( - self, - in_channels=3, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, + self, + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, ): super().__init__( in_channels=in_channels, @@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) - return self.out(h) \ No newline at end of file + return self.out(h) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index c9cd7a36..dcef3909 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -1,14 +1,15 @@ -from inspect import isfunction -from abc import abstractmethod import math +from abc import abstractmethod +from inspect import isfunction import numpy as np import torch import torch.nn as nn import torch.nn.functional as F + try: - from einops import repeat, rearrange + from einops import rearrange, repeat except: print("Einops is not installed") pass @@ -16,12 +17,13 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin + def exists(val): return val is not None def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -53,20 +55,13 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -90,17 +85,17 @@ class LinearAttention(nn.Module): super().__init__() self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) return self.to_out(out) @@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -183,31 +159,34 @@ class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... -> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) - out = torch.einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None): + + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] ) - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention @@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module): x_in = x x = self.norm(x) x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") for block in self.transformer_blocks: x = block(x, context=context) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) x = self.proj_out(x) return x + x_in + def convert_module_to_f16(l): """ Convert primitive modules to float16. @@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module): output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -453,9 +427,7 @@ class Upsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: @@ -472,7 +444,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -480,9 +452,7 @@ class Downsample(nn.Module): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -558,17 +528,13 @@ class ResBlock(TimestepBlock): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += torch.DoubleTensor([matmul_ops]) @@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -773,14 +737,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin): use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, ): super().__init__() - + # register all __init__ params with self.register self.register( image_size=image_size, @@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ) if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels @@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( @@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + ) + if not use_spatial_transformer + else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) ) @@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ), + ) + if not use_spatial_transformer + else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim), ResBlock( ch, time_embed_dim, @@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( @@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + ) + if not use_spatial_transformer + else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) ) @@ -1024,10 +990,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module): use_new_attention_order=False, pool="adaptive", *args, - **kwargs + **kwargs, ): super().__init__() @@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module): ) self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module): down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), ) elif pool == "spatial": self.out = nn.Sequential( @@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module): else: h = h.type(x.dtype) return self.out(h) - diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 7389b7f5..209647d1 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -20,10 +20,9 @@ from typing import Optional, Union from huggingface_hub import snapshot_download -from .utils import logging, DIFFUSERS_CACHE - from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module +from .utils import DIFFUSERS_CACHE, logging INDEX_FILE = "diffusion_model.pt" @@ -106,7 +105,7 @@ class DiffusionPipeline(ConfigMixin): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Add docstrings + Add docstrings """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py index 0bcf59d2..c3724a2d 100644 --- a/src/diffusers/schedulers/gaussian_ddpm.py +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import math + +import torch from torch import nn from ..configuration_utils import ConfigMixin -from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar +from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule SAMPLING_CONFIG_NAME = "scheduler_config.json" diff --git a/src/diffusers/schedulers/glide_ddim.py b/src/diffusers/schedulers/glide_ddim.py index 91f62ea3..8b5d86bd 100644 --- a/src/diffusers/schedulers/glide_ddim.py +++ b/src/diffusers/schedulers/glide_ddim.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import numpy as np +import torch from torch import nn from ..configuration_utils import ConfigMixin -from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar +from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule SAMPLING_CONFIG_NAME = "scheduler_config.json" @@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin): config_name = SAMPLING_CONFIG_NAME - def __init__( - self, - timesteps=1000, - beta_schedule="linear", - variance_type="fixed_large" - ): + def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"): super().__init__() self.register( timesteps=timesteps, @@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin): return torch.randn(shape, generator=generator).to(device) def __len__(self): - return self.num_timesteps \ No newline at end of file + return self.num_timesteps diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 45b9f64e..7f25da44 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -5,6 +5,8 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. +import os + # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from requests.exceptions import HTTPError -import os + hf_cache_home = os.path.expanduser( os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index b5f95e46..61e9e833 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -14,19 +14,19 @@ # limitations under the License. +import os import random import tempfile import unittest -import os from distutils.util import strtobool import torch from diffusers import GaussianDDPMScheduler, UNetModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin -from models.vision.ddpm.modeling_ddpm import DDPM +from diffusers.pipeline_utils import DiffusionPipeline from models.vision.ddim.modeling_ddim import DDIM +from models.vision.ddpm.modeling_ddpm import DDPM global_rng = random.Random() @@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase): ConfigMixin.from_config("dummy_path") def test_save_load(self): - class SampleObject(ConfigMixin): config_name = "config.json" @@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase): class SamplerTesterMixin(unittest.TestCase): - @slow def test_sample(self): generator = torch.manual_seed(0) @@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase): model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise - image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) + image = scheduler.sample_noise( + (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator + ) # 3. Denoise for t in reversed(range(len(scheduler))): # i) define coefficients for time step t clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) - image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) - clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) + image_coeff = ( + (1 - scheduler.get_alpha_prod(t - 1)) + * torch.sqrt(scheduler.get_alpha(t)) + / (1 - scheduler.get_alpha_prod(t)) + ) + clipped_coeff = ( + torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) + ) # ii) predict noise residual with torch.no_grad(): @@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) image_slice = image[0, -1, -3:, -3:].cpu() - expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]) + expected_slice = torch.tensor( + [-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 def test_sample_fast(self): @@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase): model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise - image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) + image = scheduler.sample_noise( + (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator + ) # 3. Denoise for t in reversed(range(len(scheduler))): # i) define coefficients for time step t clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) - image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) - clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) + image_coeff = ( + (1 - scheduler.get_alpha_prod(t - 1)) + * torch.sqrt(scheduler.get_alpha(t)) + / (1 - scheduler.get_alpha_prod(t)) + ) + clipped_coeff = ( + torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) + ) # ii) predict noise residual with torch.no_grad(): @@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase): - def test_from_pretrained_save_pretrained(self): # 1. Load models model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) @@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -1, -3:, -3:].cpu() assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]) + expected_slice = torch.tensor( + [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068] + ) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2