[flax safety checker] Use `FlaxPreTrainedModel` for saving/loading (#591)
* use FlaxPreTrainedModel for flax safety module * fix name * fix one more * Apply suggestions from code review
This commit is contained in:
parent
8a6833b85c
commit
c6629e6f11
|
@ -1,4 +1,5 @@
|
|||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -6,13 +7,9 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
from flax import linen as nn
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.struct import field
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers import CLIPConfig, FlaxPreTrainedModel
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
|
||||
|
||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ...modeling_flax_utils import FlaxModelMixin
|
||||
|
||||
|
||||
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
|
||||
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
|
||||
|
@ -20,34 +17,17 @@ def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
|
|||
return jnp.matmul(norm_emb_1, norm_emb_2.T)
|
||||
|
||||
|
||||
@flax_register_to_config
|
||||
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
projection_dim: int = 768
|
||||
# CLIPVisionConfig fields
|
||||
vision_config: dict = field(default_factory=dict)
|
||||
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
|
||||
config: CLIPConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
# init input tensor
|
||||
input_shape = (
|
||||
1,
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["num_channels"],
|
||||
)
|
||||
pixel_values = jax.random.normal(rng, input_shape)
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
return self.init(rngs, pixel_values)["params"]
|
||||
|
||||
def setup(self):
|
||||
clip_vision_config = CLIPVisionConfig(**self.vision_config)
|
||||
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
|
||||
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
|
||||
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
|
||||
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
|
||||
|
||||
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
|
||||
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
|
||||
self.special_care_embeds = self.param(
|
||||
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
|
||||
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
|
||||
)
|
||||
|
||||
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
|
||||
|
@ -109,3 +89,59 @@ class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
|
||||
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
main_input_name = "clip_input"
|
||||
module_class = FlaxStableDiffusionSafetyCheckerModule
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPConfig,
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = (1, 224, 224, 3)
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
clip_input = jax.random.normal(rng, input_shape)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(rngs, clip_input)["params"]
|
||||
|
||||
return random_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_input,
|
||||
params: dict = None,
|
||||
):
|
||||
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(clip_input, dtype=jnp.float32),
|
||||
rngs={},
|
||||
)
|
||||
|
||||
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
|
||||
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
|
||||
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
special_cos_dist,
|
||||
cos_dist,
|
||||
images,
|
||||
method=_filtered_with_scores,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue