Bump to 0.10.0.dev0 + deprecations (#1490)
This commit is contained in:
parent
eeeb28a9ad
commit
999044596a
|
@ -14,7 +14,6 @@ from datasets import load_dataset
|
||||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
|
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.training_utils import EMAModel
|
from diffusers.training_utils import EMAModel
|
||||||
from diffusers.utils import deprecate
|
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torchvision.transforms import (
|
from torchvision.transforms import (
|
||||||
|
@ -417,11 +416,7 @@ def main(args):
|
||||||
scheduler=noise_scheduler,
|
scheduler=noise_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
|
generator = torch.Generator(device=pipeline.device).manual_seed(0)
|
||||||
if diffusers_version < version.parse("0.8.0"):
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=pipeline.device).manual_seed(0)
|
|
||||||
# run pipeline in inference (sample random noise and denoise)
|
# run pipeline in inference (sample random noise and denoise)
|
||||||
images = pipeline(
|
images = pipeline(
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -214,7 +214,7 @@ install_requires = [
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.10.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="Diffusers",
|
description="Diffusers",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|
|
@ -9,7 +9,7 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.9.0"
|
__version__ = "0.10.0.dev0"
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .onnx_utils import OnnxRuntimeModel
|
from .onnx_utils import OnnxRuntimeModel
|
||||||
|
|
|
@ -15,16 +15,15 @@
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, whoami
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
|
from .utils import ENV_VARS_TRUE_VALUES, logging
|
||||||
from .utils.import_utils import (
|
from .utils.import_utils import (
|
||||||
_flax_version,
|
_flax_version,
|
||||||
_jax_version,
|
_jax_version,
|
||||||
|
@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||||
return f"{organization}/{model_id}"
|
return f"{organization}/{model_id}"
|
||||||
|
|
||||||
|
|
||||||
def init_git_repo(args, at_init: bool = False):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
Initializes a git repo in `args.hub_model_id`.
|
|
||||||
at_init (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
|
|
||||||
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
|
|
||||||
"""
|
|
||||||
deprecation_message = (
|
|
||||||
"Please use `huggingface_hub.Repository`. "
|
|
||||||
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
|
|
||||||
)
|
|
||||||
deprecate("init_git_repo()", "0.10.0", deprecation_message)
|
|
||||||
|
|
||||||
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
|
||||||
return
|
|
||||||
hub_token = args.hub_token if hasattr(args, "hub_token") else None
|
|
||||||
use_auth_token = True if hub_token is None else hub_token
|
|
||||||
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
|
|
||||||
repo_name = Path(args.output_dir).absolute().name
|
|
||||||
else:
|
|
||||||
repo_name = args.hub_model_id
|
|
||||||
if "/" not in repo_name:
|
|
||||||
repo_name = get_full_repo_name(repo_name, token=hub_token)
|
|
||||||
|
|
||||||
try:
|
|
||||||
repo = Repository(
|
|
||||||
args.output_dir,
|
|
||||||
clone_from=repo_name,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
private=args.hub_private_repo,
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
if args.overwrite_output_dir and at_init:
|
|
||||||
# Try again after wiping output_dir
|
|
||||||
shutil.rmtree(args.output_dir)
|
|
||||||
repo = Repository(
|
|
||||||
args.output_dir,
|
|
||||||
clone_from=repo_name,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
repo.git_pull()
|
|
||||||
|
|
||||||
# By default, ignore the checkpoint folders
|
|
||||||
if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
|
|
||||||
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
|
||||||
writer.writelines(["checkpoint-*/"])
|
|
||||||
|
|
||||||
return repo
|
|
||||||
|
|
||||||
|
|
||||||
def push_to_hub(
|
|
||||||
args,
|
|
||||||
pipeline,
|
|
||||||
repo: Repository,
|
|
||||||
commit_message: Optional[str] = "End of training",
|
|
||||||
blocking: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Parameters:
|
|
||||||
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
|
|
||||||
commit_message (`str`, *optional*, defaults to `"End of training"`):
|
|
||||||
Message to commit while pushing.
|
|
||||||
blocking (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether the function should return only when the `git push` has finished.
|
|
||||||
kwargs:
|
|
||||||
Additional keyword arguments passed along to [`create_model_card`].
|
|
||||||
Returns:
|
|
||||||
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
|
|
||||||
commit and an object to track the progress of the commit if `blocking=True`
|
|
||||||
"""
|
|
||||||
deprecation_message = (
|
|
||||||
"Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
|
|
||||||
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
|
|
||||||
)
|
|
||||||
deprecate("push_to_hub()", "0.10.0", deprecation_message)
|
|
||||||
|
|
||||||
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
|
|
||||||
model_name = Path(args.output_dir).name
|
|
||||||
else:
|
|
||||||
model_name = args.hub_model_id.split("/")[-1]
|
|
||||||
|
|
||||||
output_dir = args.output_dir
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
logger.info(f"Saving pipeline checkpoint to {output_dir}")
|
|
||||||
pipeline.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
# Only push from one node.
|
|
||||||
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
|
|
||||||
if (
|
|
||||||
blocking
|
|
||||||
and len(repo.command_queue) > 0
|
|
||||||
and repo.command_queue[-1] is not None
|
|
||||||
and not repo.command_queue[-1].is_done
|
|
||||||
):
|
|
||||||
repo.command_queue[-1]._process.kill()
|
|
||||||
|
|
||||||
git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
|
|
||||||
# push separately the model card to be independent from the rest of the model
|
|
||||||
create_model_card(args, model_name=model_name)
|
|
||||||
try:
|
|
||||||
repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
|
|
||||||
except EnvironmentError as exc:
|
|
||||||
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
|
|
||||||
|
|
||||||
return git_head_commit_url
|
|
||||||
|
|
||||||
|
|
||||||
def create_model_card(args, model_name):
|
def create_model_card(args, model_name):
|
||||||
if not is_modelcards_available:
|
if not is_modelcards_available:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
|
||||||
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
|
||||||
"""
|
|
||||||
Recursively unwraps a model from potential containers (as used in distributed training).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (`torch.nn.Module`): The model to unwrap.
|
|
||||||
"""
|
|
||||||
# since there could be multiple levels of wrapping, unwrap recursively
|
|
||||||
if hasattr(model, "module"):
|
|
||||||
return unwrap_model(model.module)
|
|
||||||
else:
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _get_model_file(
|
def _get_model_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
*,
|
*,
|
||||||
|
|
|
@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
|
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
new_config = dict(self.scheduler.config)
|
new_config = dict(self.scheduler.config)
|
||||||
|
|
|
@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
|
|
@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
new_config = dict(self.config)
|
new_config = dict(self.config)
|
||||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||||
|
|
|
@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
new_config = dict(self.config)
|
new_config = dict(self.config)
|
||||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||||
|
|
|
@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
|
|
@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||||
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None:
|
if predict_epsilon is not None:
|
||||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_inference_deprecated_predict_epsilon(self):
|
def test_inference_deprecated_predict_epsilon(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.11.0", "remove")
|
||||||
unet = self.dummy_uncond_unet
|
unet = self.dummy_uncond_unet
|
||||||
scheduler = DDPMScheduler(predict_epsilon=False)
|
scheduler = DDPMScheduler(predict_epsilon=False)
|
||||||
|
|
||||||
|
|
|
@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
|
||||||
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
|
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
|
||||||
|
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
deprecate("remove this case", "0.10.0", "remove")
|
deprecate("remove this case", "0.11.0", "remove")
|
||||||
ddpm_3 = DDPMScheduler.from_pretrained(
|
ddpm_3 = DDPMScheduler.from_pretrained(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
|
|
|
@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||||
self.check_over_configs(prediction_type=prediction_type)
|
self.check_over_configs(prediction_type=prediction_type)
|
||||||
|
|
||||||
def test_deprecated_predict_epsilon(self):
|
def test_deprecated_predict_epsilon(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.11.0", "remove")
|
||||||
for predict_epsilon in [True, False]:
|
for predict_epsilon in [True, False]:
|
||||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||||
|
|
||||||
def test_deprecated_epsilon(self):
|
def test_deprecated_epsilon(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.11.0", "remove")
|
||||||
scheduler_class = self.scheduler_classes[0]
|
scheduler_class = self.scheduler_classes[0]
|
||||||
scheduler_config = self.get_scheduler_config()
|
scheduler_config = self.get_scheduler_config()
|
||||||
|
|
||||||
|
|
|
@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||||
self.check_over_configs(prediction_type=prediction_type)
|
self.check_over_configs(prediction_type=prediction_type)
|
||||||
|
|
||||||
def test_deprecated_predict_epsilon(self):
|
def test_deprecated_predict_epsilon(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.11.0", "remove")
|
||||||
for predict_epsilon in [True, False]:
|
for predict_epsilon in [True, False]:
|
||||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||||
|
|
||||||
def test_deprecated_predict_epsilon_to_prediction_type(self):
|
def test_deprecated_predict_epsilon_to_prediction_type(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.11.0", "remove")
|
||||||
for scheduler_class in self.scheduler_classes:
|
for scheduler_class in self.scheduler_classes:
|
||||||
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
|
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
Loading…
Reference in New Issue