* docs: `src/diffusers` readability improvements
Signed-off-by: Ryan Russell <git@ryanrussell.org>
* docs: `make style` lint
Signed-off-by: Ryan Russell <git@ryanrussell.org>
* refactor: pipelines readability improvements
Signed-off-by: Ryan Russell <git@ryanrussell.org>
* docs: remove todo comment from flax pipeline
Signed-off-by: Ryan Russell <git@ryanrussell.org>
* Add `pred_original_sample` to the `SchedulerOutput` returned by the `step` methods of DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, and KarrasVeScheduler, so callers can access the predicted denoised sample
* Give DDPMScheduler, DDIMScheduler, and LMSDiscreteScheduler their own output dataclasses, so the default `SchedulerOutput` in `scheduling_utils` does not need `pred_original_sample` as an optional extra field
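A minimal sketch of the per-scheduler output dataclass these commits describe, assuming the diffusers field-naming convention; the library's exact definitions may differ:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DDPMSchedulerOutput:
    # Sample from the previous timestep (x_{t-1}), fed into the next
    # denoising step.
    prev_sample: torch.FloatTensor
    # The model's prediction of the fully denoised sample (x_0), exposed
    # so callers can inspect the predicted denoised output at each step.
    pred_original_sample: Optional[torch.FloatTensor] = None
```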
* Reorder library imports to follow the standard convention
* Fix import order (first attempt wasn't quite right)
* Fix missed rename of `LMSDiscreteSchedulerOutput`
* Add the extra libraries needed for `make style` to run fully
* add gradient checkpointing to downsample blocks
* make it work
* don't pass gradient_checkpointing to upsample block
* add tests for UNet2DConditionModel
* add test_gradient_checkpointing
* add gradient_checkpointing for up and down blocks
* add functions to enable and disable gradient checkpointing
* remove the forward argument
* better naming
* make supports_gradient_checkpointing private
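A usage sketch of the enable/disable API these commits add; the checkpoint id is illustrative:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# Recompute down/up block activations during the backward pass instead of
# storing them, trading compute for memory.
unet.enable_gradient_checkpointing()

# ... training steps ...

unet.disable_gradient_checkpointing()
```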
* Optionally return state in from_config.
Useful for Flax schedulers.
* `has_state` is now a property; make the check more strict.
I don't check that the class is a `SchedulerMixin`, to prevent circular
dependencies. It should be enough that the class name starts with "Flax",
the object declares `has_state`, and a `create_state` method exists too.
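A sketch of the resulting loading pattern, with an illustrative checkpoint id; Flax schedulers are stateless, so loading hands back both the scheduler and its initial state:

```python
from diffusers import FlaxPNDMScheduler

# Returns a (scheduler, state) tuple; the state pytree, not the scheduler,
# is what gets threaded through the denoising loop.
scheduler, scheduler_state = FlaxPNDMScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
```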
* Use state in pipeline from_pretrained.
* Make style
* Fix typo in docstring.
* Allow dtype to be overridden on model load.
This may be a temporary solution until #567 is addressed.
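A sketch of the dtype override, with an illustrative checkpoint id:

```python
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Load the parameters directly in bfloat16, e.g. for TPU inference.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.bfloat16
)
```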
* Create latents in float32
The denoising loop always computes the next step in float32, so creating
the latents in the model dtype would fail when using `bfloat16`.
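A sketch of the fix, with an illustrative latent shape:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
latents_shape = (1, 4, 64, 64)  # illustrative (batch, channels, h/8, w/8)

# Sample in float32 regardless of the model dtype: each scheduler step is
# computed in float32, so bfloat16 latents would break the loop.
latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)
```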
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
* todo comment
* Fix imports
* Fix imports
* add dummies
* Fix empty init
* make pipeline work
* up
* Use Flax schedulers (typing, docstring)
* Wrap model imports inside availability checks.
* more updates
* make sure flax is not broken
* make style
* more fixes
* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
* first commit:
- add `from_pt` argument to the `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
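A sketch of the new flag in use; the checkpoint id is illustrative:

```python
from diffusers import FlaxUNet2DConditionModel

# `from_pt=True` converts PyTorch weights to Flax on the fly, using the
# key-matching logic in `modeling_flax_pytorch_utils.py`.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", from_pt=True
)
```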
* small nit
- fix a small nit so we don't enter the second `if` condition
* major changes
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
* keys match
- now all keys match
- change module names for correct matching
- upsample module name changed
* working v1
- tests pass with atol and rtol = `4e-02`
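A sketch of the parity check behind this, with placeholder arrays standing in for the real Flax and PyTorch UNet outputs:

```python
import numpy as np

pt_output = np.random.randn(1, 4, 64, 64).astype(np.float32)
flax_output = pt_output + 1e-3  # small numerical drift, within tolerance

assert np.allclose(flax_output, pt_output, atol=4e-2, rtol=4e-2)
```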
* replace unused arg
* make quality
* add small docstring
* add more comments
- add TODO for embedding layers
* small change
- use `jnp.expand_dims` to convert `timesteps` in case it is a 0-dimensional array
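A sketch of the 0-dimensional guard:

```python
import jax.numpy as jnp

timesteps = jnp.array(10)  # 0-dimensional scalar timestep

# Promote 0-d timesteps to shape (1,) so downstream embedding layers
# always receive a batched array.
if timesteps.ndim == 0:
    timesteps = jnp.expand_dims(timesteps, 0)
```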
* add more conditions on conversion
- add better test to check for keys conversion
* make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
* fix unet shape
- channels first!
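An illustrative sketch of the channels-first convention, not the library's exact code: the Flax UNet keeps PyTorch's NCHW layout at its boundary, while Flax convolutions are NHWC internally:

```python
import jax.numpy as jnp

sample = jnp.zeros((1, 4, 64, 64))                 # NCHW, matching PyTorch
sample_nhwc = jnp.transpose(sample, (0, 2, 3, 1))  # to NHWC for Flax convs
# ... down/mid/up blocks operate on NHWC ...
sample = jnp.transpose(sample_nhwc, (0, 3, 1, 2))  # back to channels-first
```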