[flax safety checker] Use `FlaxPreTrainedModel` for saving/loading (#591)
* use FlaxPreTrainedModel for flax safety module * fix name * fix one more * Apply suggestions from code review
This commit is contained in:
parent
8a6833b85c
commit
c6629e6f11
|
@ -1,4 +1,5 @@
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -6,13 +7,9 @@ import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax import linen as nn
|
from flax import linen as nn
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.struct import field
|
from transformers import CLIPConfig, FlaxPreTrainedModel
|
||||||
from transformers import CLIPVisionConfig
|
|
||||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
|
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
|
||||||
from ...modeling_flax_utils import FlaxModelMixin
|
|
||||||
|
|
||||||
|
|
||||||
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
|
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
|
||||||
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
|
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
|
||||||
|
@ -20,34 +17,17 @@ def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
|
||||||
return jnp.matmul(norm_emb_1, norm_emb_2.T)
|
return jnp.matmul(norm_emb_1, norm_emb_2.T)
|
||||||
|
|
||||||
|
|
||||||
@flax_register_to_config
|
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
|
||||||
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
|
config: CLIPConfig
|
||||||
projection_dim: int = 768
|
|
||||||
# CLIPVisionConfig fields
|
|
||||||
vision_config: dict = field(default_factory=dict)
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
|
||||||
# init input tensor
|
|
||||||
input_shape = (
|
|
||||||
1,
|
|
||||||
self.vision_config["image_size"],
|
|
||||||
self.vision_config["image_size"],
|
|
||||||
self.vision_config["num_channels"],
|
|
||||||
)
|
|
||||||
pixel_values = jax.random.normal(rng, input_shape)
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
|
||||||
return self.init(rngs, pixel_values)["params"]
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
clip_vision_config = CLIPVisionConfig(**self.vision_config)
|
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
|
||||||
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
|
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
|
||||||
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
|
|
||||||
|
|
||||||
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
|
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
|
||||||
self.special_care_embeds = self.param(
|
self.special_care_embeds = self.param(
|
||||||
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
|
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
|
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
|
||||||
|
@ -109,3 +89,59 @@ class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
return images, has_nsfw_concepts
|
return images, has_nsfw_concepts
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
|
||||||
|
config_class = CLIPConfig
|
||||||
|
main_input_name = "clip_input"
|
||||||
|
module_class = FlaxStableDiffusionSafetyCheckerModule
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: CLIPConfig,
|
||||||
|
input_shape: Optional[Tuple] = None,
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if input_shape is None:
|
||||||
|
input_shape = (1, 224, 224, 3)
|
||||||
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
|
# init input tensor
|
||||||
|
clip_input = jax.random.normal(rng, input_shape)
|
||||||
|
|
||||||
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
|
random_params = self.module.init(rngs, clip_input)["params"]
|
||||||
|
|
||||||
|
return random_params
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
clip_input,
|
||||||
|
params: dict = None,
|
||||||
|
):
|
||||||
|
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
return self.module.apply(
|
||||||
|
{"params": params or self.params},
|
||||||
|
jnp.array(clip_input, dtype=jnp.float32),
|
||||||
|
rngs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
|
||||||
|
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
|
||||||
|
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
|
||||||
|
|
||||||
|
return self.module.apply(
|
||||||
|
{"params": params or self.params},
|
||||||
|
special_cos_dist,
|
||||||
|
cos_dist,
|
||||||
|
images,
|
||||||
|
method=_filtered_with_scores,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue