Bump to 0.10.0.dev0 + deprecations (#1490)

Anton Lozhkov committed 2022-11-30 15:27:56 +01:00 (committed via GitHub)
parent eeeb28a9ad
commit 999044596a
16 changed files with 20 additions and 155 deletions

examples/unconditional_image_generation/train_unconditional.py

@@ -14,7 +14,6 @@ from datasets import load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
 from packaging import version
 from torchvision.transforms import (
@@ -417,11 +416,7 @@ def main(args):
             scheduler=noise_scheduler,
         )
-        deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
-        if diffusers_version < version.parse("0.8.0"):
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=pipeline.device).manual_seed(0)
+        generator = torch.Generator(device=pipeline.device).manual_seed(0)
 
         # run pipeline in inference (sample random noise and denoise)
         images = pipeline(
             generator=generator,
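
With the version check gone, the example script always seeds a torch.Generator pinned to the pipeline's device instead of calling the global torch.manual_seed(), which was only needed for diffusers < 0.8.0. A minimal sketch of the pattern (the model id is an example, not taken from this commit):

    import torch
    from diffusers import DDPMPipeline

    # Seed a Generator on the pipeline's device so sampling is reproducible
    # without mutating global RNG state. "google/ddpm-cifar10-32" is an example id.
    pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")
    generator = torch.Generator(device=pipeline.device).manual_seed(0)
    images = pipeline(generator=generator, batch_size=4).images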

setup.py

@@ -214,7 +214,7 @@ install_requires = [
 setup(
     name="diffusers",
-    version="0.9.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.10.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="Diffusers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",

src/diffusers/__init__.py

@@ -9,7 +9,7 @@ from .utils import (
 )
 
-__version__ = "0.9.0"
+__version__ = "0.10.0.dev0"
 
 from .configuration_utils import ConfigMixin
 from .onnx_utils import OnnxRuntimeModel
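
The version string lives in both setup.py and src/diffusers/__init__.py, and this commit bumps the two together. A quick sanity check after a source install (an illustration, not part of the commit):

    import diffusers

    # Both setup.py and __init__.py were bumped; a source install should report:
    assert diffusers.__version__ == "0.10.0.dev0", diffusers.__version__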

src/diffusers/hub_utils.py

@@ -15,16 +15,15 @@
 
 import os
-import shutil
 import sys
 from pathlib import Path
 from typing import Dict, Optional, Union
 from uuid import uuid4
 
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import HfFolder, whoami
 
 from . import __version__
-from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
+from .utils import ENV_VARS_TRUE_VALUES, logging
 from .utils.import_utils import (
     _flax_version,
     _jax_version,
@@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     return f"{organization}/{model_id}"
 
 
-def init_git_repo(args, at_init: bool = False):
-    """
-    Args:
-    Initializes a git repo in `args.hub_model_id`.
-        at_init (`bool`, *optional*, defaults to `False`):
-            Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
-            and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
-    """
-    deprecation_message = (
-        "Please use `huggingface_hub.Repository`. "
-        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
-    )
-    deprecate("init_git_repo()", "0.10.0", deprecation_message)
-    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
-        return
-    hub_token = args.hub_token if hasattr(args, "hub_token") else None
-    use_auth_token = True if hub_token is None else hub_token
-    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
-        repo_name = Path(args.output_dir).absolute().name
-    else:
-        repo_name = args.hub_model_id
-    if "/" not in repo_name:
-        repo_name = get_full_repo_name(repo_name, token=hub_token)
-
-    try:
-        repo = Repository(
-            args.output_dir,
-            clone_from=repo_name,
-            use_auth_token=use_auth_token,
-            private=args.hub_private_repo,
-        )
-    except EnvironmentError:
-        if args.overwrite_output_dir and at_init:
-            # Try again after wiping output_dir
-            shutil.rmtree(args.output_dir)
-            repo = Repository(
-                args.output_dir,
-                clone_from=repo_name,
-                use_auth_token=use_auth_token,
-            )
-        else:
-            raise
-
-    repo.git_pull()
-
-    # By default, ignore the checkpoint folders
-    if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
-        with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
-            writer.writelines(["checkpoint-*/"])
-
-    return repo
-
-
-def push_to_hub(
-    args,
-    pipeline,
-    repo: Repository,
-    commit_message: Optional[str] = "End of training",
-    blocking: bool = True,
-    **kwargs,
-) -> str:
-    """
-    Parameters:
-    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
-        commit_message (`str`, *optional*, defaults to `"End of training"`):
-            Message to commit while pushing.
-        blocking (`bool`, *optional*, defaults to `True`):
-            Whether the function should return only when the `git push` has finished.
-        kwargs:
-            Additional keyword arguments passed along to [`create_model_card`].
-    Returns:
-        The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
-        commit and an object to track the progress of the commit if `blocking=True`
-    """
-    deprecation_message = (
-        "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
-        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
-    )
-    deprecate("push_to_hub()", "0.10.0", deprecation_message)
-
-    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
-        model_name = Path(args.output_dir).name
-    else:
-        model_name = args.hub_model_id.split("/")[-1]
-
-    output_dir = args.output_dir
-    os.makedirs(output_dir, exist_ok=True)
-    logger.info(f"Saving pipeline checkpoint to {output_dir}")
-    pipeline.save_pretrained(output_dir)
-
-    # Only push from one node.
-    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
-        return
-
-    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
-    if (
-        blocking
-        and len(repo.command_queue) > 0
-        and repo.command_queue[-1] is not None
-        and not repo.command_queue[-1].is_done
-    ):
-        repo.command_queue[-1]._process.kill()
-
-    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
-    # push separately the model card to be independent from the rest of the model
-    create_model_card(args, model_name=model_name)
-
-    try:
-        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
-    except EnvironmentError as exc:
-        logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
-
-    return git_head_commit_url
-
-
 def create_model_card(args, model_name):
     if not is_modelcards_available:
         raise ValueError(
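
Both deleted helpers carried deprecation messages pointing at huggingface_hub.Repository and Repository.push_to_hub(); train_unconditional.py shows the replacement in context. A minimal sketch of that flow (the repo id and local directory are hypothetical placeholders):

    from diffusers import DDPMPipeline
    from huggingface_hub import Repository

    # Clone (or create) the Hub repo locally, save the pipeline into it, push.
    # "your-username/my-ddpm-model" and "ddpm-output" are placeholder names.
    pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    repo = Repository("ddpm-output", clone_from="your-username/my-ddpm-model")
    pipeline.save_pretrained("ddpm-output")
    repo.push_to_hub(commit_message="End of training", blocking=True, auto_lfs_prune=True)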

src/diffusers/modeling_utils.py

@@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
         return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
 
-def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
-    """
-    Recursively unwraps a model from potential containers (as used in distributed training).
-
-    Args:
-        model (`torch.nn.Module`): The model to unwrap.
-    """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
-    else:
-        return model
-
-
 def _get_model_file(
     pretrained_model_name_or_path,
     *,
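
The removed helper only peeled the .module attribute that containers such as torch.nn.parallel.DistributedDataParallel add. Training code that still needs this can rely on accelerate, which exposes an equivalent; a sketch under that assumption (accelerate is not touched by this commit):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(4, 4))  # may wrap the module for distributed runs
    plain = accelerator.unwrap_model(model)             # recovers the underlying torch.nn.Module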

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             new_config = dict(self.scheduler.config)
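
This hunk and every scheduler hunk below make the same change: the removal target for the deprecated predict_epsilon kwarg moves from 0.10.0 (which this commit makes the current dev version) to 0.11.0, giving users one more release to migrate. The migration is the one the message spells out; a sketch (the model id is an example):

    from diffusers import DDPMScheduler

    # New style: state the prediction type explicitly.
    scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32", prediction_type="epsilon")

    # Deprecated style, still accepted (with a warning) until 0.11.0:
    # scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32", predict_epsilon=True)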

src/diffusers/schedulers/scheduling_ddim.py

@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

src/diffusers/schedulers/scheduling_ddim_flax.py

@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

src/diffusers/schedulers/scheduling_ddpm.py

@@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
@@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             new_config = dict(self.config)
             new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"

src/diffusers/schedulers/scheduling_ddpm_flax.py

@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             new_config = dict(self.config)
             new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
             " FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
-        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

tests/pipelines/ddpm/test_ddpm.py

@@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_inference_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.10.0", "remove")
+        deprecate("remove this test", "0.11.0", "remove")
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler(predict_epsilon=False)
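
The deprecate("remove this test", "0.11.0", "remove") calls here and in the test files below are self-deprecating markers: diffusers' deprecate() warns while the package version is below the stated removal version and raises once it reaches it, so these tests fail loudly when 0.11.0 ships and must then be deleted along with the kwarg. A simplified sketch of that mechanism (not diffusers' exact implementation):

    import warnings
    from packaging import version

    CURRENT_VERSION = "0.10.0.dev0"  # stand-in for diffusers.__version__

    def deprecate(name: str, removal_version: str, message: str):
        # Warn until the removal version, then raise so stale deprecations
        # (including self-deprecating tests) cannot be forgotten.
        if version.parse(CURRENT_VERSION) >= version.parse(removal_version):
            raise ValueError(f"`{name}` should have been removed by {removal_version}: {message}")
        warnings.warn(f"`{name}` is deprecated, removal in {removal_version}. {message}", FutureWarning)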

tests/test_config.py

@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
         ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
 
         with CaptureLogger(logger) as cap_logger:
-            deprecate("remove this case", "0.10.0", "remove")
+            deprecate("remove this case", "0.11.0", "remove")
             ddpm_3 = DDPMScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch",
                 subfolder="scheduler",

tests/test_scheduler.py

@@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(prediction_type=prediction_type)
 
     def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.10.0", "remove")
+        deprecate("remove this test", "0.11.0", "remove")
         for predict_epsilon in [True, False]:
             self.check_over_configs(predict_epsilon=predict_epsilon)
 
     def test_deprecated_epsilon(self):
-        deprecate("remove this test", "0.10.0", "remove")
+        deprecate("remove this test", "0.11.0", "remove")
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()

tests/test_scheduler_flax.py

@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
             self.check_over_configs(prediction_type=prediction_type)
 
     def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.10.0", "remove")
+        deprecate("remove this test", "0.11.0", "remove")
         for predict_epsilon in [True, False]:
             self.check_over_configs(predict_epsilon=predict_epsilon)
 
     def test_deprecated_predict_epsilon_to_prediction_type(self):
-        deprecate("remove this test", "0.10.0", "remove")
+        deprecate("remove this test", "0.11.0", "remove")
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config(predict_epsilon=True)
             scheduler = scheduler_class.from_config(scheduler_config)