* Fix typo in docstring.
* Allow dtype to be overridden on model load.
This may be a temporary solution until #567 is addressed.
* Create latents in float32
The denoising loop always computes the next step in float32, so this
would fail when using `bfloat16`.
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
* todo comment
* Fix imports
* Fix imports
* add dummies
* Fix empty init
* make pipeline work
* up
* Use Flax schedulers (typing, docstring)
* Wrap model imports inside availability checks.
* more updates
* make sure flax is not broken
* make style
* more fixes
* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
* first commit:
- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
* small nit
- fix a small nit - to not enter in the second if condition
* major changes
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
* keys match
- now all keys match
- change module names for correct matching
- upsample module name changed
* working v1
- test pass with atol and rtol= `4e-02`
* replace unsued arg
* make quality
* add small docstring
* add more comments
- add TODO for embedding layers
* small change
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
- add better test to check for keys conversion
* make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
* fix unet shape
- channels first!
* Unify offset configuration in DDIM and PNDM schedulers
* Format
Add missing variables
* Fix pipeline test
* Update src/diffusers/schedulers/scheduling_ddim.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Default set_alpha_to_one to false
* Format
* Add tests
* Format
* add deprecation warning
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Fix typos
* Add a typo check action
* Fix a bug
* Changed to manual typo check currently
Ref: https://github.com/huggingface/diffusers/pull/483#pullrequestreview-1104468010
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
* Removed a confusing message
* Renamed "nin_shortcut" to "in_shortcut"
* Add memo about NIN
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
* Fix is_onnx_available
Fix: If user install onnxruntime-gpu, is_onnx_available() will return False.
* add more onnxruntime candidates
* Run `make style`
Co-authored-by: anton-l <anton@huggingface.co>
* begin text2img conversion script
* add fn to convert config
* create config if not provided
* update imports and use UNet2DConditionModel
* fix imports, layer names
* fix unet coversion
* add function to convert VAE
* fix vae conversion
* update main
* create text model
* update config creating logic for unet
* fix config creation
* update script to create and save pipeline
* remove unused imports
* fix checkpoint loading
* better name
* save progress
* finish
* up
* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* First UNet Flax modeling blocks.
Mimic the structure of the PyTorch files.
The model classes themselves need work, depending on what we do about
configuration and initialization.
* Remove FlaxUNet2DConfig class.
* ignore_for_config non-config args.
* Implement `FlaxModelMixin`
* Use new mixins for Flax UNet.
For some reason the configuration is not correctly applied; the
signature of the `__init__` method does not contain all the parameters
by the time it's inspected in `extract_init_dict`.
* Import `FlaxUNet2DConditionModel` if flax is available.
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Indicate types in flax.struct.dataclass as pointed out by @mishig25
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
* Fix typo in transformer block.
* make style
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Style: remove empty line.
* Apply style (after upgrading black from pinned version)
* Remove some commented code and unused imports.
* Add init_weights (not yet in use until #513).
* Trickle down deterministic to blocks.
* Rename q, k, v according to the latest PyTorch version.
Note that weights were exported with the old names, so we need to be
careful.
* Flax UNet docstrings, default props as in PyTorch.
* Fix minor typos in PyTorch docstrings.
* Use FlaxUNet2DConditionOutput as output from UNet.
* make style
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* beta never changes removed from state
* fix typos in docs
* removed unused var
* initial ddim flax scheduler
* import
* added dummy objects
* fix style
* fix typo
* docs
* fix typo in comment
* set return type
* added flax ddom
* fix style
* remake
* pass PRNG key as argument and split before use
* fix doc string
* use config
* added flax Karras VE scheduler
* make style
* fix dummy
* fix ndarray type annotation
* replace returns a new state
* added lms_discrete scheduler
* use self.config
* add_noise needs state
* use config
* use config
* docstring
* added flax score sde ve
* fix imports
* fix typos
* add different method for sliced attention
* Update src/diffusers/models/attention.py
* Apply suggestions from code review
* Update src/diffusers/models/attention.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>