misc fixes (#2282)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 648090e26e
commit fd5c3c09af

@@ -48,7 +48,10 @@ if __name__ == "__main__":
         "--pipeline_type",
         default=None,
         type=str,
-        help="The pipeline type. If `None` pipeline will be automatically inferred.",
+        help=(
+            "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
+            ". If `None` pipeline will be automatically inferred."
+        ),
     )
     parser.add_argument(
         "--image_size",

@@ -65,7 +68,7 @@ if __name__ == "__main__":
         type=str,
         help=(
             "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
-            " Siffusion v2 Base. Use 'v-prediction' for Stable Diffusion v2."
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
         ),
     )
     parser.add_argument(

@@ -79,8 +82,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--upcast_attention",
-        default=False,
-        type=bool,
+        action="store_true",
         help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."

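The switch from `type=bool` to `action="store_true"` fixes a classic argparse pitfall: argparse applies `bool()` to the raw string, and any non-empty string (including "False") is truthy, so the old flag could never be switched off from the command line. A minimal standalone sketch (hypothetical flag names, not from this diff) of the difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--broken", default=False, type=bool)  # old pattern
    parser.add_argument("--fixed", action="store_true")        # new pattern

    # bool("False") is True: any non-empty string parses as True.
    print(parser.parse_args(["--broken", "False"]).broken)  # True (surprising)
    # A store_true flag is False when absent and True when passed.
    print(parser.parse_args([]).fixed)                      # False
    print(parser.parse_args(["--fixed"]).fixed)             # True
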
@@ -111,5 +113,6 @@ if __name__ == "__main__":
         num_in_channels=args.num_in_channels,
+        upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
     )
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

@@ -13,17 +13,12 @@
 # limitations under the License.
 
 
-import warnings
+import numpy as np  # noqa: E402
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...schedulers.scheduling_utils import SchedulerMixin
 
 
-warnings.filterwarnings("ignore")
-
-import numpy as np  # noqa: E402
-
-
 try:
     import librosa  # noqa: E402

@@ -39,10 +39,13 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBe
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 
-from ...utils import is_omegaconf_available, is_safetensors_available
+from ...utils import is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 def shave_segments(path, n_shave_prefix_segments=1):
     """
     Removes segments. Positive values shave the first segments, negative shave the last segments.

@@ -801,11 +804,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         corresponding to the original architecture. If `None`, will be
         automatically inferred by looking for a key that only exists in SD2.0 models.
     :param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
-        Siffusion v2
+        Diffusion v2
         Base. Use 768 for Stable Diffusion v2.
     :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
         v1.X and Stable
-        Siffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
+        Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
     :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
         inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
         "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of

@@ -820,6 +823,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
         StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
+    if prediction_type == "v-prediction":
+        prediction_type = "v_prediction"
 
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

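With the alias in place, callers can pass either spelling and get the same parameterization. A hedged usage sketch (the file paths are placeholders, not from this diff):

    # Both spellings now select v-prediction; the function normalizes
    # "v-prediction" to "v_prediction" before any downstream checks.
    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path="./v2-1_768.ckpt",             # placeholder path
        original_config_file="./v2-inference-v.yaml",  # placeholder path
        image_size=768,
        prediction_type="v-prediction",                # normalized internally
    )
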
@@ -957,6 +962,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     # Convert the text model.
     if model_type is None:
         model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
 
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)

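The inferred value is just the last dotted segment of the original LDM config's cond_stage_config target, which is why it lines up with the `--pipeline_type` choices above. A small sketch of that string manipulation (example target taken from the standard SD2 config layout):

    # e.g. original_config.model.params.cond_stage_config.target
    target = "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder"
    model_type = target.split(".")[-1]
    print(model_type)  # "FrozenOpenCLIPEmbedder"
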
@@ -305,7 +305,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
         if eta > 0:
-            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
             device = model_output.device
             if variance_noise is not None and generator is not None:
                 raise ValueError(

@@ -106,8 +106,11 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
         clip_sample: bool = True,
         clip_sample_range: Optional[float] = 1.0,
         prediction_type: str = "epsilon",
+        beta_schedule: str = "squaredcos_cap_v2",
     ):
-        # beta scheduler is "squaredcos_cap_v2"
+        if beta_schedule != "squaredcos_cap_v2":
+            raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'")
+
         self.betas = betas_for_alpha_bar(num_train_timesteps)
 
         self.alphas = 1.0 - self.betas

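The constructor now accepts `beta_schedule` for config compatibility but rejects anything other than the one schedule it actually implements, instead of documenting the restriction in a comment. A minimal sketch of the new behavior (assuming the public diffusers import path):

    from diffusers import UnCLIPScheduler

    UnCLIPScheduler(beta_schedule="squaredcos_cap_v2")  # fine: the only supported schedule

    try:
        UnCLIPScheduler(beta_schedule="linear")
    except ValueError as err:
        print(err)  # UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'
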
@@ -17,23 +17,36 @@ import requests
 from packaging import version
 
 from .import_utils import is_flax_available, is_onnx_available, is_torch_available
+from .logging import get_logger
 
 
 global_rng = random.Random()
 
+logger = get_logger(__name__)
+
 if is_torch_available():
     import torch
 
-    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-    is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
-        "1.12"
-    )
     if "DIFFUSERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
 
-    if is_torch_higher_equal_than_1_12:
-        # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
-        mps_backend_registered = hasattr(torch.backends, "mps")
-        torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+        available_backends = ["cuda", "cpu", "mps"]
+        if torch_device not in available_backends:
+            raise ValueError(
+                f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:"
+                f" {available_backends}"
+            )
+        logger.info(f"torch_device overrode to {torch_device}")
+    else:
+        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        is_torch_higher_equal_than_1_12 = version.parse(
+            version.parse(torch.__version__).base_version
+        ) >= version.parse("1.12")
+
+        if is_torch_higher_equal_than_1_12:
+            # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
+            mps_backend_registered = hasattr(torch.backends, "mps")
+            torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
 
 
 def torch_all_close(a, b, *args, **kwargs):

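A sketch of how the override is exercised; setting the variable before the first import is the key detail, since the validation runs at module import time:

    import os

    # Must be set before diffusers.utils.testing_utils is imported; an unknown
    # value such as "tpu" now raises ValueError instead of failing later.
    os.environ["DIFFUSERS_TEST_DEVICE"] = "cpu"

    from diffusers.utils.testing_utils import torch_device

    print(torch_device)  # "cpu"
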
@@ -30,6 +30,7 @@ from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_diff
 
 class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = UnCLIPPipeline
+    test_xformers_attention = False
 
     required_optional_params = [
         "generator",

@@ -259,6 +259,7 @@ class PipelineTesterMixin:
             # Taking the median of the largest <n> differences
             # is resilient to outliers
             diff = np.abs(output_batch[0][0] - output[0][0])
             diff = diff.flatten()
+            diff.sort()
             max_diff = np.median(diff[-5:])
         else:

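Why sort before taking the median of the tail: without `diff.sort()`, `diff[-5:]` is just the last five pixels in memory order, not the five largest differences, so the statistic was meaningless. A self-contained numpy sketch of why median-of-largest beats a plain max:

    import numpy as np

    diff = np.full(4096, 1e-4)  # mostly tiny differences
    diff[0] = 1.0               # a single outlier pixel

    diff.sort()
    print(diff.max())            # 1.0  -- a plain max is hijacked by the outlier
    print(np.median(diff[-5:]))  # 1e-4 -- median of the 5 largest shrugs it off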