cancel einops
parent 4e08e0ca42
commit 932ce05d97
@@ -173,7 +173,7 @@ if __name__ == "__main__":
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3/4)
+    parser.add_argument("--ema_power", type=float, default=3 / 4)
     parser.add_argument("--ema_max_decay", type=float, default=0.999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
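Aside: the --ema_* flags above typically feed a warmup-style EMA decay schedule. The snippet below is a hypothetical sketch of how such hyperparameters are commonly combined, not necessarily the exact EMAModel code used by this training script:

def ema_decay(step, inv_gamma=1.0, power=3 / 4, max_decay=0.999):
    # Decay starts near 0 for early steps and ramps toward max_decay.
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max_decay, max(0.0, value))

# ema_decay(10) ~= 0.83, ema_decay(10_000) hits the 0.999 cap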
@@ -13,13 +13,6 @@ from .embeddings import get_timestep_embedding
 from .resnet import Upsample
 
-
-# try:
-#     from einops import rearrange, repeat
-# except:
-#     print("Einops is not installed")
-#     pass
-
 
 def exists(val):
     return val is not None
 
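Since the commit drops the commented-out einops import above, einops-style rearrange/repeat calls have to be expressed with native tensor ops instead. A rough, hypothetical sketch of the usual equivalences (illustrative only, not code from this file):

import torch

x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)

# rearrange(x, "b c h w -> b (h w) c") with plain PyTorch:
b, c, h, w = x.shape
y = x.reshape(b, c, h * w).permute(0, 2, 1)  # (b, h*w, c)

# repeat(t, "b -> b d", d=16) for a 1-D tensor t:
t = torch.arange(2)
z = t[:, None].expand(-1, 16)  # (2, 16)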
@@ -17,7 +17,6 @@ from ..modeling_utils import ModelMixin
 # pass
 
-
 
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
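Only the head of SinusoidalPosEmb appears in this hunk; for orientation, a typical forward pass for this kind of module looks roughly like the sketch below (assumed shape conventions, not the exact body from this file):

import math

import torch
from torch import nn


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: 1-D tensor of timesteps with shape (batch,)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)  # (batch, dim)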
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Union
 
 import numpy as np
 import torch
 
-from typing import Union
+
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -53,20 +53,16 @@ class SchedulerMixin:
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def match_shape(
-        self,
-        values: Union[np.ndarray, torch.Tensor],
-        broadcast_array: Union[np.ndarray, torch.Tensor]
-    ):
+    def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
 
         Args:
             timesteps: an array or tensor of values to extract.
             broadcast_array: an array with a larger shape of K dimensions with the batch
                 dimension equal to the length of timesteps.
         Returns:
             a tensor of shape [batch_size, 1, ...] where the shape has K dims.
         """
 
         tensor_format = getattr(self, "tensor_format", "pt")
 
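Per the docstring, match_shape reshapes per-timestep values so they broadcast against a batched tensor. A hypothetical stand-alone sketch of that behaviour (names and shapes assumed, not the repo's implementation):

import torch


def match_shape_sketch(values, broadcast_array):
    # (batch,) -> (batch, 1, 1, ...) with as many dims as broadcast_array
    return values.reshape(-1, *((1,) * (broadcast_array.ndim - 1)))


alphas = torch.rand(4)             # one scalar per sample
noise = torch.randn(4, 3, 32, 32)  # batch of image-shaped tensors
scaled = match_shape_sketch(alphas, noise) * noise  # broadcasts cleanly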
@@ -21,7 +21,8 @@ import unittest
 import numpy as np
 import torch
 
-from diffusers import ( # GradTTSPipeline,
+from diffusers import (
+    GradTTSPipeline,
     BDDMPipeline,
     DDIMPipeline,
     DDIMScheduler,