make style 2 - sorry
This commit is contained in:
parent
97ef5e0665
commit
04ad948673
|
@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset):
|
|||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(h, w,) = (
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
|
|
@ -432,7 +432,10 @@ class TextualInversionDataset(Dataset):
|
|||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(h, w,) = (
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
|
|
@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset):
|
|||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(h, w,) = (
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
|
|
@ -94,8 +94,10 @@ class AttentionBlock(nn.Module):
|
|||
if use_memory_efficient_attention_xformers:
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
" xformers"
|
||||
),
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
|
|
|
@ -111,8 +111,10 @@ class CrossAttention(nn.Module):
|
|||
)
|
||||
elif not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
" xformers"
|
||||
),
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
|
|
|
@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
|
|
|
@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
|
|
|
@ -537,8 +537,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
|||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "scale_model_input"),
|
||||
(
|
||||
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
|
||||
" timestep)`",
|
||||
" timestep)`"
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "step"),
|
||||
|
|
Loading…
Reference in New Issue