diffusers/tests
Kashif Rasul 37d113cce7
DiT Pipeline (#1806)
* added dit model

* import

* initial pipeline

* initial convert script

* initial pipeline

* make style

* raise valueerror

* single function

* rename classes

* use DDIMScheduler

* timesteps embedder

* samples to cpu

* fix var names

* fix numpy type

* use timesteps class for proj

* fix typo

* fix arg name

* flip_sin_to_cos and better var names

* fix C shape cal

* make style

* remove unused imports

* cleanup

* add back patch_size

* initial dit doc

* typo

* Update docs/source/api/pipelines/dit.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* added copyright license headers

* added example usage and toc

* fix variable names asserts

* remove comment

* added docs

* fix typo

* upstream changes

* set proper device for drop_ids

* added initial dit pipeline test

* update docs

* fix imports

* make fix-copies

* isort

* fix imports

* get rid of more magic numbers

* fix code when guidance is off

* remove block_kwargs

* cleanup script

* removed to_2tuple

* use FeedForward class instead of another MLP

* style

* work on merging DiTBlock with BasicTransformerBlock

* added missing final_dropout and args to BasicTransformerBlock

* use norm from block

* fix arg

* remove unused arg

* fix call to class_embedder

* use timesteps

* make style

* attn_output gets multiplied

* removed commented code

* use Transformer2D

* use self.is_input_patches

* fix flags

* fixed conversion to use Transformer2DModel

* fixes for pipeline

* remove dit.py

* fix timesteps device

* use randn_tensor and fix fp16 inf.

* timesteps_emb already the right dtype

* fix dit test class

* fix test and style

* fix norm2 usage in vq-diffusion

* added author names to pipeline and ImageNet labels link

* fix tests

* use norm_type as string

* rename dit to transformer

* fix name

* fix test

* set norm_type = "layer" by default

* fix tests

* do not skip common tests

* Update src/diffusers/models/attention.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* revert AdaLayerNorm API

* fix norm_type name

* make sure all components are in eval mode

* revert norm2 API

* compact

* finish deprecation

* add slow tests

* remove @

* refactor some stuff

* upload

* Update src/diffusers/pipelines/dit/pipeline_dit.py

* finish more

* finish docs

* improve docs

* finish docs

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-01-17 23:09:29 +01:00
fixtures/custom_pipeline Make repo structure consistent (#1862) 2022-12-30 11:51:08 +01:00
models Make repo structure consistent (#1862) 2022-12-30 11:51:08 +01:00
pipelines DiT Pipeline (#1806) 2023-01-17 23:09:29 +01:00
repo_utils Move accelerate to a soft-dependency (#1134) 2022-11-04 14:58:52 +01:00
__init__.py fix issues with loading, add test for pipeline 2022-06-07 15:40:36 +02:00
conftest.py [Utils] Add deprecate function and move testing_utils under utils (#659) 2022-10-03 23:44:24 +02:00
test_config.py Bump to 0.12.0.dev0 (#1771) 2022-12-19 18:44:08 +01:00
test_layers_utils.py Test ResnetBlock2D (#1850) 2023-01-04 22:57:32 +01:00
test_modeling_common.py Make repo structure consistent (#1862) 2022-12-30 11:51:08 +01:00
test_modeling_common_flax.py Allow to set config params directly in init (#1419) 2022-11-25 15:07:09 +01:00
test_outputs.py Fix BaseOutput initialization from dict (#570) 2022-09-20 18:32:16 +02:00
test_pipelines.py Allow converting Flax to PyTorch by adding a "from_flax" keyword (#1900) 2023-01-12 20:00:35 +01:00
test_pipelines_common.py DiT Pipeline (#1806) 2023-01-17 23:09:29 +01:00
test_pipelines_flax.py Make height and width optional (#1401) 2022-11-24 18:23:59 +01:00
test_pipelines_onnx_common.py Reorganize pipeline tests (#963) 2022-10-24 16:34:01 +02:00
test_scheduler.py [Black] Update black library (#2007) 2023-01-16 15:16:28 +01:00
test_scheduler_flax.py [Flax] Stateless schedulers, fixes and refactors (#1661) 2022-12-20 01:42:41 +01:00
test_training.py [Utils] Add deprecate function and move testing_utils under utils (#659) 2022-10-03 23:44:24 +02:00
test_unet_2d_blocks.py Add tests for 2D UNet blocks (#1945) 2023-01-16 12:53:05 +01:00
test_unet_blocks_common.py Add tests for 2D UNet blocks (#1945) 2023-01-16 12:53:05 +01:00
test_utils.py [Deprecate] Correct stacklevel (#1483) 2022-12-01 16:28:10 +01:00