make style

This commit is contained in:
Patrick von Platen 2023-01-31 11:46:48 +00:00
parent d1efefe15e
commit 60d915fbed
2 changed files with 4 additions and 4 deletions

View File

@ -45,7 +45,6 @@ def preprocess_image(image):
def preprocess_mask(mask, scale_factor=8):
if not isinstance(mask, torch.FloatTensor):
mask = mask.convert("L")
w, h = mask.size
@ -65,7 +64,8 @@ def preprocess_mask(mask, scale_factor=8):
mask = mask.permute(0, 3, 1, 2)
elif mask.shape[1] not in valid_mask_channel_sizes:
raise ValueError(
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
f" but received mask of shape {tuple(mask.shape)}"
)
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
mask = mask.mean(dim=1, keepdim=True)