Flax: Trickle down `norm_num_groups` (#789)
* pass norm_num_groups param and add tests * set resnet_groups for FlaxUNetMidBlock2D * fixed docstrings * fixed typo * using is_flax_available util and created require_flax decorator
This commit is contained in:
parent
66a5279a94
commit
a124204490
|
@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
|
|||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
The number of groups to use for group norm.
|
||||
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
|
||||
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
|
@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
|
|||
in_channels: int
|
||||
out_channels: int = None
|
||||
dropout: float = 0.0
|
||||
groups: int = 32
|
||||
use_nin_shortcut: bool = None
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
||||
self.conv1 = nn.Conv(
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
|
@ -143,7 +146,7 @@ class FlaxResnetBlock2D(nn.Module):
|
|||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
||||
self.dropout_layer = nn.Dropout(self.dropout)
|
||||
self.conv2 = nn.Conv(
|
||||
out_channels,
|
||||
|
@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module):
|
|||
Input channels
|
||||
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
|
||||
Number of attention heads
|
||||
num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
The number of groups to use for group norm
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
|
||||
"""
|
||||
channels: int
|
||||
num_head_channels: int = None
|
||||
num_groups: int = 32
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
|
@ -204,7 +210,7 @@ class FlaxAttentionBlock(nn.Module):
|
|||
|
||||
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
|
||||
|
||||
self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
|
||||
self.query, self.key, self.value = dense(), dense(), dense()
|
||||
self.proj_attn = dense()
|
||||
|
||||
|
@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
The number of groups to use for the Resnet block group norm
|
||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add downsample layer
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
|
@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
resnet_groups: int = 32
|
||||
add_downsample: bool = True
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
|
@ -285,6 +294,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout=self.dropout,
|
||||
groups=self.resnet_groups,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
@ -303,9 +313,9 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
class FlaxUpEncoderBlock2D(nn.Module):
|
||||
class FlaxUpDecoderBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
|
||||
Flax Resnet blocks-based Decoder block for diffusion-based VAE.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
|
@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add downsample layer
|
||||
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
The number of groups to use for the Resnet block group norm
|
||||
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add upsample layer
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
resnet_groups: int = 32
|
||||
add_upsample: bool = True
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
|
@ -336,6 +349,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout=self.dropout,
|
||||
groups=self.resnet_groups,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
The number of groups to use for the Resnet and Attention block group norm
|
||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
|
||||
Number of attention heads for each attention block
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
|
@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||
in_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
resnet_groups: int = 32
|
||||
attn_num_head_channels: int = 1
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
FlaxResnetBlock2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout=self.dropout,
|
||||
groups=resnet_groups,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
]
|
||||
|
@ -392,7 +412,10 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||
|
||||
for _ in range(self.num_layers):
|
||||
attn_block = FlaxAttentionBlock(
|
||||
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
|
||||
channels=self.in_channels,
|
||||
num_head_channels=self.attn_num_head_channels,
|
||||
num_groups=resnet_groups,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
|
||||
|
@ -400,6 +423,7 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout=self.dropout,
|
||||
groups=resnet_groups,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
|
|||
Tuple containing the number of output channels for each block
|
||||
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
||||
Number of Resnet layer for each block
|
||||
norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
|
||||
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
norm num group
|
||||
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
||||
Activation function
|
||||
|
@ -483,6 +507,7 @@ class FlaxEncoder(nn.Module):
|
|||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block,
|
||||
resnet_groups=self.norm_num_groups,
|
||||
add_downsample=not is_final_block,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
@ -491,12 +516,15 @@ class FlaxEncoder(nn.Module):
|
|||
|
||||
# middle
|
||||
self.mid_block = FlaxUNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_groups=self.norm_num_groups,
|
||||
attn_num_head_channels=None,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# end
|
||||
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
||||
self.conv_out = nn.Conv(
|
||||
conv_out_channels,
|
||||
kernel_size=(3, 3),
|
||||
|
@ -581,7 +609,10 @@ class FlaxDecoder(nn.Module):
|
|||
|
||||
# middle
|
||||
self.mid_block = FlaxUNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_groups=self.norm_num_groups,
|
||||
attn_num_head_channels=None,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
|
@ -594,10 +625,11 @@ class FlaxDecoder(nn.Module):
|
|||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = FlaxUpEncoderBlock2D(
|
||||
up_block = FlaxUpDecoderBlock2D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
resnet_groups=self.norm_num_groups,
|
||||
add_upsample=not is_final_block,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
@ -607,7 +639,7 @@ class FlaxDecoder(nn.Module):
|
|||
self.up_blocks = up_blocks
|
||||
|
||||
# end
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
||||
self.conv_out = nn.Conv(
|
||||
self.out_channels,
|
||||
kernel_size=(3, 3),
|
||||
|
|
|
@ -14,6 +14,8 @@ import PIL.ImageOps
|
|||
import requests
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_flax_available
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
@ -89,6 +91,13 @@ def slow(test_case):
|
|||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
"""
|
||||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
Args:
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
def test_output(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
|
||||
jax.lax.stop_gradient(variables)
|
||||
|
||||
output = model.apply(variables, inputs_dict["sample"])
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
|
||||
jax.lax.stop_gradient(variables)
|
||||
|
||||
output = model.apply(variables, inputs_dict["sample"])
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
|
@ -0,0 +1,39 @@
|
|||
import unittest
|
||||
|
||||
from diffusers import FlaxAutoencoderKL
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
from .test_modeling_common_flax import FlaxModelTesterMixin
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
|
||||
model_class = FlaxAutoencoderKL
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
prng_key = jax.random.PRNGKey(0)
|
||||
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
|
||||
|
||||
return {"sample": image, "prng_key": prng_key}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
Loading…
Reference in New Issue