cancel einops

This commit is contained in:
Patrick von Platen 2022-06-27 15:39:41 +00:00
parent 4e08e0ca42
commit 932ce05d97
5 changed files with 13 additions and 24 deletions

View File

@ -173,7 +173,7 @@ if __name__ == "__main__":
parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0) parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3/4) parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.999) parser.add_argument("--ema_max_decay", type=float, default=0.999)
parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_token", type=str, default=None)

View File

@ -13,13 +13,6 @@ from .embeddings import get_timestep_embedding
from .resnet import Upsample from .resnet import Upsample
# try:
# from einops import rearrange, repeat
# except:
# print("Einops is not installed")
# pass
def exists(val): def exists(val):
return val is not None return val is not None

View File

@ -17,7 +17,6 @@ from ..modeling_utils import ModelMixin
# pass # pass
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()

View File

@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union
import numpy as np import numpy as np
import torch import torch
from typing import Union
SCHEDULER_CONFIG_NAME = "scheduler_config.json" SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@ -53,20 +53,16 @@ class SchedulerMixin:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def match_shape( def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
self,
values: Union[np.ndarray, torch.Tensor],
broadcast_array: Union[np.ndarray, torch.Tensor]
):
""" """
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args: Args:
timesteps: an array or tensor of values to extract. timesteps: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps. dimension equal to the length of timesteps.
Returns: Returns:
a tensor of shape [batch_size, 1, ...] where the shape has K dims. a tensor of shape [batch_size, 1, ...] where the shape has K dims.
""" """
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")

View File

@ -21,7 +21,8 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from diffusers import ( # GradTTSPipeline, from diffusers import (
GradTTSPipeline,
BDDMPipeline, BDDMPipeline,
DDIMPipeline, DDIMPipeline,
DDIMScheduler, DDIMScheduler,