[flax safety checker] Use `FlaxPreTrainedModel` for saving/loading (#591)

* use FlaxPreTrainedModel for flax safety module

* fix name

* fix one more

* Apply suggestions from code review
This commit is contained in:
Suraj Patil 2022-09-20 20:11:32 +02:00 committed by GitHub
parent 8a6833b85c
commit c6629e6f11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 64 additions and 28 deletions

View File

@ -1,4 +1,5 @@
import warnings
from typing import Optional, Tuple
import numpy as np
@ -6,13 +7,9 @@ import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.struct import field
from transformers import CLIPVisionConfig
from transformers import CLIPConfig, FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...modeling_flax_utils import FlaxModelMixin
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
@ -20,34 +17,17 @@ def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
return jnp.matmul(norm_emb_1, norm_emb_2.T)
@flax_register_to_config
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
projection_dim: int = 768
# CLIPVisionConfig fields
vision_config: dict = field(default_factory=dict)
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
config: CLIPConfig
dtype: jnp.dtype = jnp.float32
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensor
input_shape = (
1,
self.vision_config["image_size"],
self.vision_config["image_size"],
self.vision_config["num_channels"],
)
pixel_values = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.init(rngs, pixel_values)["params"]
def setup(self):
clip_vision_config = CLIPVisionConfig(**self.vision_config)
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
self.special_care_embeds = self.param(
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
)
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
@ -109,3 +89,59 @@ class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
)
return images, has_nsfw_concepts
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
config_class = CLIPConfig
main_input_name = "clip_input"
module_class = FlaxStableDiffusionSafetyCheckerModule
def __init__(
self,
config: CLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if input_shape is None:
input_shape = (1, 224, 224, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, clip_input)["params"]
return random_params
def __call__(
self,
clip_input,
params: dict = None,
):
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
return self.module.apply(
{"params": params or self.params},
jnp.array(clip_input, dtype=jnp.float32),
rngs={},
)
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
return self.module.apply(
{"params": params or self.params},
special_cos_dist,
cos_dist,
images,
method=_filtered_with_scores,
)