make style
This commit is contained in:
parent
d1efefe15e
commit
60d915fbed
|
@ -45,7 +45,6 @@ def preprocess_image(image):
|
|||
|
||||
|
||||
def preprocess_mask(mask, scale_factor=8):
|
||||
|
||||
if not isinstance(mask, torch.FloatTensor):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
|
@ -65,7 +64,8 @@ def preprocess_mask(mask, scale_factor=8):
|
|||
mask = mask.permute(0, 3, 1, 2)
|
||||
elif mask.shape[1] not in valid_mask_channel_sizes:
|
||||
raise ValueError(
|
||||
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
|
||||
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
|
||||
f" but received mask of shape {tuple(mask.shape)}"
|
||||
)
|
||||
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
|
|
Loading…
Reference in New Issue